diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index aec2445b9..9c2f47469 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -750,19 +750,12 @@ objc_library( ], ) -proto_library( +mediapipe_proto_library( name = "scale_mode_proto", srcs = ["scale_mode.proto"], visibility = ["//visibility:public"], ) -mediapipe_cc_proto_library( - name = "scale_mode_cc_proto", - srcs = ["scale_mode.proto"], - visibility = ["//visibility:public"], - deps = [":scale_mode_proto"], -) - cc_library( name = "gl_quad_renderer", srcs = ["gl_quad_renderer.cc"], diff --git a/mediapipe/model_maker/models/gesture_recognizer/BUILD b/mediapipe/model_maker/models/gesture_recognizer/BUILD new file mode 100644 index 000000000..f8e5cdd21 --- /dev/null +++ b/mediapipe/model_maker/models/gesture_recognizer/BUILD @@ -0,0 +1,51 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"], +) + +mediapipe_files( + srcs = [ + "canned_gesture_classifier.tflite", + "gesture_embedder.tflite", + "gesture_embedder/keras_metadata.pb", + "gesture_embedder/saved_model.pb", + "gesture_embedder/variables/variables.data-00000-of-00001", + "gesture_embedder/variables/variables.index", + "hand_landmark_full.tflite", + "palm_detection_full.tflite", + ], +) + +filegroup( + name = "models", + srcs = [ + "canned_gesture_classifier.tflite", + "gesture_embedder.tflite", + "gesture_embedder/keras_metadata.pb", + "gesture_embedder/saved_model.pb", + "gesture_embedder/variables/variables.data-00000-of-00001", + "gesture_embedder/variables/variables.index", + "hand_landmark_full.tflite", + "palm_detection_full.tflite", + ], +) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD new file mode 100644 index 000000000..b7d334d9c --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -0,0 +1,165 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. 
+ +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced +# directory. +filegroup( + name = "test_data", + srcs = glob([ + "test_data/**", + ]), +) + +py_library( + name = "constants", + srcs = ["constants.py"], +) + +# TODO: Change to py_library after migrating the MediaPipe hand solution +# library to MediaPipe hand task library. +py_library( + name = "dataset", + srcs = ["dataset.py"], + deps = [ + ":constants", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/data:data_util", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/python/solutions:hands", + ], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + data = [ + ":test_data", + "//mediapipe/model_maker/models/gesture_recognizer:models", + ], + deps = [ + ":dataset", + "//mediapipe/python/solutions:hands", + "//mediapipe/tasks/python/test:test_utils", + ], +) + +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], + deps = [ + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) + +py_library( + name = "model_options", + srcs = ["model_options.py"], +) + +py_library( + name = "gesture_recognizer_options", + srcs = ["gesture_recognizer_options.py"], + deps = [ + ":hyperparameters", + ":model_options", + ], +) + +py_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.py"], + data = ["//mediapipe/model_maker/models/gesture_recognizer:models"], + deps = [ + ":constants", + ":gesture_recognizer_options", + ":hyperparameters", + ":metadata_writer", + ":model_options", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:loss_functions", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + ], +) + +py_library( + name = "gesture_recognizer_import", + srcs = ["__init__.py"], + deps = [ + ":dataset", + ":gesture_recognizer", + ":gesture_recognizer_options", + ":hyperparameters", + ":model_options", + ], +) + +py_library( + name = "metadata_writer", + srcs = ["metadata_writer.py"], + deps = [ + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + "//mediapipe/tasks/python/metadata/metadata_writers:writer_utils", + ], +) + +py_test( + name = "gesture_recognizer_test", + size = "large", + srcs = ["gesture_recognizer_test.py"], + data = [ + ":test_data", + "//mediapipe/model_maker/models/gesture_recognizer:models", + ], + shard_count = 2, + deps = [ + ":gesture_recognizer_import", + "//mediapipe/model_maker/python/core/utils:test_util", + "//mediapipe/tasks/python/test:test_utils", + ], +) + +py_test( + name = "metadata_writer_test", + srcs = ["metadata_writer_test.py"], + data = [ + ":test_data", + ], + deps = [ + ":metadata_writer", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + "//mediapipe/tasks/python/test:test_utils", + ], +) + +py_binary( + name = "gesture_recognizer_demo", + srcs = ["gesture_recognizer_demo.py"], + data = [ + ":test_data", + "//mediapipe/model_maker/models/gesture_recognizer:models", + ], + python_version = "PY3", + deps = [":gesture_recognizer_import"], +) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py new 
file mode 100644 index 000000000..dc6923fac --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe Model Maker Python Public API For Gesture Recognizer.""" + +from mediapipe.model_maker.python.vision.gesture_recognizer import dataset +from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer +from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options + +GestureRecognizer = gesture_recognizer.GestureRecognizer +ModelOptions = model_options.GestureRecognizerModelOptions +HParams = hyperparameters.HParams +Dataset = dataset.Dataset +HandDataPreprocessingParams = dataset.HandDataPreprocessingParams +GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py new file mode 100644 index 000000000..ac9bba12a --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py @@ -0,0 +1,20 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gesture recognition constants.""" + +GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder' +GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite' +HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite' +HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite' +CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite' diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py new file mode 100644 index 000000000..256f26fd6 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -0,0 +1,238 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gesture recognition dataset library."""
+
+import dataclasses
+import os
+import random
+from typing import List, NamedTuple, Optional
+
+import cv2
+import tensorflow as tf
+
+from mediapipe.model_maker.python.core.data import classification_dataset
+from mediapipe.model_maker.python.core.data import data_util
+from mediapipe.model_maker.python.core.utils import model_util
+from mediapipe.model_maker.python.vision.gesture_recognizer import constants
+from mediapipe.python.solutions import hands as mp_hands
+
+
+@dataclasses.dataclass
+class HandDataPreprocessingParams:
+  """A dataclass that wraps the hand data preprocessing hyperparameters.
+
+  Attributes:
+    shuffle: Whether to shuffle the dataset. Defaults to True.
+    min_detection_confidence: Confidence threshold for hand detection.
+  """
+  shuffle: bool = True
+  min_detection_confidence: float = 0.7
+
+
+@dataclasses.dataclass
+class HandData:
+  """A dataclass that represents hand data for training the gesture recognizer.
+
+  See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
+  more details of the hand gesture data API.
+
+  Attributes:
+    hand: Normalized hand landmarks of shape 21x3 from the screen based
+      hand-landmark model.
+    world_hand: Hand landmarks of shape 21x3 in world coordinates.
+    handedness: Collection of handedness confidence of the detected hands (i.e.
+      is it a left or right hand).
+  """
+  hand: List[List[float]]
+  world_hand: List[List[float]]
+  handedness: List[float]
+
+
+def _validate_data_sample(data: NamedTuple) -> bool:
+  """Validates the input hand data sample.
+
+  Args:
+    data: Input hand data sample.
+
+  Returns:
+    False if the input data namedtuple does not contain the fields
+    'multi_hand_landmarks', 'multi_hand_world_landmarks' and 'multi_handedness'
+    or if any of these attributes' values are None. Otherwise, True.
+  """
+  if (not hasattr(data, 'multi_hand_landmarks') or
+      data.multi_hand_landmarks is None):
+    return False
+  if (not hasattr(data, 'multi_hand_world_landmarks') or
+      data.multi_hand_world_landmarks is None):
+    return False
+  if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
+    return False
+  return True
+
+
+def _get_hand_data(
+    all_image_paths: List[str],
+    min_detection_confidence: float) -> List[Optional[HandData]]:
+  """Computes hand data (landmarks and handedness) for each input image.
+
+  Args:
+    all_image_paths: All input image paths.
+    min_detection_confidence: Hand detection confidence threshold.
+
+  Returns:
+    A list with one entry per input image: a HandData object if a valid hand
+    is detected in the image, or None otherwise.
+  """
+  hand_data_result = []
+  with mp_hands.Hands(
+      static_image_mode=True,
+      max_num_hands=1,
+      min_detection_confidence=min_detection_confidence) as hands:
+    for path in all_image_paths:
+      tf.compat.v1.logging.info('Loading image %s', path)
+      image = data_util.load_image(path)
+      # Flip image around y-axis for correct handedness output.
+      image = cv2.flip(image, 1)
+      data = hands.process(image)
+      if not _validate_data_sample(data):
+        hand_data_result.append(None)
+        continue
+      hand_landmarks = [[
+          hand_landmark.x, hand_landmark.y, hand_landmark.z
+      ] for hand_landmark in data.multi_hand_landmarks[0].landmark]
+      hand_world_landmarks = [[
+          hand_landmark.x, hand_landmark.y, hand_landmark.z
+      ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
+      handedness_scores = [
+          handedness.score
+          for handedness in data.multi_handedness[0].classification
+      ]
+      hand_data_result.append(
+          HandData(
+              hand=hand_landmarks,
+              world_hand=hand_world_landmarks,
+              handedness=handedness_scores))
+  return hand_data_result
+
+
+class Dataset(classification_dataset.ClassificationDataset):
+  """Dataset library for hand gesture recognizer."""
+
+  @classmethod
+  def from_folder(
+      cls,
+      dirname: str,
+      hparams: Optional[HandDataPreprocessingParams] = None
+  ) -> classification_dataset.ClassificationDataset:
+    """Loads images and labels from the given directory.
+
+    Directory contents are expected to be in the format:
+    <root_dir>/<gesture_name>/*.jpg. One of the `gesture_name` sub-directories
+    must be `none` (case insensitive). The `none` sub-directory is expected to
+    contain images of hands that don't belong to other gesture classes in
+    <root_dir>. Assumes the image data of the same label are in the same
+    subdirectory.
+
+    Args:
+      dirname: Name of the directory containing the data files.
+      hparams: Optional hyperparameters for processing input hand gesture
+        images.
+
+    Returns:
+      Dataset containing landmarks, labels, and other related info.
+
+    Raises:
+      ValueError: if the input data directory is empty or the label set does
+        not contain label 'none' (case insensitive).
+    """
+    data_root = os.path.abspath(dirname)
+
+    # Assumes the image data of the same label are in the same subdirectory;
+    # gets image paths and label names.
+    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
+    if not all_image_paths:
+      raise ValueError('Image dataset directory is empty.')
+
+    if not hparams:
+      hparams = HandDataPreprocessingParams()
+
+    if hparams.shuffle:
+      # Random shuffle data.
+      random.shuffle(all_image_paths)
+
+    label_names = sorted(
+        name for name in os.listdir(data_root)
+        if os.path.isdir(os.path.join(data_root, name)))
+    if 'none' not in [v.lower() for v in label_names]:
+      raise ValueError('Label set does not contain label "None".')
+    # Move label 'none' to the front of the label list.
+    none_idx = [v.lower() for v in label_names].index('none')
+    none_value = label_names.pop(none_idx)
+    label_names.insert(0, none_value)
+
+    index_by_label = dict(
+        (name, index) for index, name in enumerate(label_names))
+    all_gesture_indices = [
+        index_by_label[os.path.basename(os.path.dirname(path))]
+        for path in all_image_paths
+    ]
+
+    # Compute hand data (including local hand landmarks, world hand landmarks,
+    # and handedness) for all the input images.
+    hand_data = _get_hand_data(
+        all_image_paths=all_image_paths,
+        min_detection_confidence=hparams.min_detection_confidence)
+
+    # Get a list of the valid hand landmark samples in the hand data list.
+ valid_indices = [ + i for i in range(len(hand_data)) if hand_data[i] is not None + ] + # Remove 'None' element from the hand data and label list. + valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices] + if not valid_hand_data: + raise ValueError('No valid hand is detected.') + + valid_label = [all_gesture_indices[i] for i in valid_indices] + + # Convert list of dictionaries to dictionary of lists. + hand_data_dict = { + k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0] + } + hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict) + + embedder_model = model_util.load_keras_model( + constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH) + + hand_ds = hand_ds.batch(batch_size=1) + hand_embedding_ds = hand_ds.map( + map_func=lambda feature: embedder_model(dict(feature)), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + hand_embedding_ds = hand_embedding_ds.unbatch() + + # Create label dataset + label_ds = tf.data.Dataset.from_tensor_slices( + tf.cast(valid_label, tf.int64)) + + label_one_hot_ds = label_ds.map( + map_func=lambda index: tf.one_hot(index, len(label_names)), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + # Create a dataset with (hand_embedding, one_hot_label) pairs + hand_embedding_label_ds = tf.data.Dataset.zip( + (hand_embedding_ds, label_one_hot_ds)) + + tf.compat.v1.logging.info( + 'Load valid hands with size: {}, num_label: {}, labels: {}.'.format( + len(valid_hand_data), len(label_names), ','.join(label_names))) + return Dataset( + dataset=hand_embedding_label_ds, + size=len(valid_hand_data), + label_names=label_names) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py new file mode 100644 index 000000000..76e70a58d --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -0,0 +1,161 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved.s +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +import os +import shutil +from typing import NamedTuple +import unittest + +from absl import flags +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.vision.gesture_recognizer import dataset +from mediapipe.python.solutions import hands as mp_hands +from mediapipe.tasks.python.test import test_utils + +FLAGS = flags.FLAGS + +_TEST_DATA_DIRNAME = 'raw_data' + + +class DatasetTest(tf.test.TestCase, parameterized.TestCase): + + def test_split(self): + input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) + data = dataset.Dataset.from_folder( + dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) + train_data, test_data = data.split(0.5) + + self.assertLen(train_data, 17) + for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): + self.assertEqual(elem[0].shape, (1, 128)) + self.assertEqual(elem[1].shape, ([1, 4])) + self.assertEqual(train_data.num_classes, 4) + self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) + + self.assertLen(test_data, 18) + for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): + self.assertEqual(elem[0].shape, (1, 128)) + self.assertEqual(elem[1].shape, ([1, 4])) + self.assertEqual(test_data.num_classes, 4) + self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock']) + + def test_from_folder(self): + input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) + data = dataset.Dataset.from_folder( + dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) + for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): + self.assertEqual(elem[0].shape, (1, 128)) + self.assertEqual(elem[1].shape, ([1, 4])) + self.assertLen(data, 35) + self.assertEqual(data.num_classes, 4) + self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) + + def test_create_dataset_from_empty_folder_raise_value_error(self): + with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'): + dataset.Dataset.from_folder( + dirname=self.get_temp_dir(), + hparams=dataset.HandDataPreprocessingParams()) + + def test_create_dataset_from_folder_without_none_raise_value_error(self): + input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) + tmp_dir = self.create_tempdir() + # Copy input dataset to a temporary directory and skip 'None' directory. + for name in os.listdir(input_data_dir): + if name == 'none': + continue + src_dir = os.path.join(input_data_dir, name) + dst_dir = os.path.join(tmp_dir, name) + shutil.copytree(src_dir, dst_dir) + + with self.assertRaisesRegex(ValueError, + 'Label set does not contain label "None"'): + dataset.Dataset.from_folder( + dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams()) + + def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self): + input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) + tmp_dir = self.create_tempdir() + # Copy input dataset to a temporary directory and change the base folder + # name to upper case letter, e.g. 
'none' -> 'NONE' + for name in os.listdir(input_data_dir): + src_dir = os.path.join(input_data_dir, name) + dst_dir = os.path.join(tmp_dir, name.upper()) + shutil.copytree(src_dir, dst_dir) + + upper_base_folder_name = list(os.listdir(tmp_dir)) + self.assertCountEqual(upper_base_folder_name, + ['CALL', 'FOUR', 'NONE', 'ROCK']) + + data = dataset.Dataset.from_folder( + dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams()) + for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): + self.assertEqual(elem[0].shape, (1, 128)) + self.assertEqual(elem[1].shape, ([1, 4])) + self.assertLen(data, 35) + self.assertEqual(data.num_classes, 4) + self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK']) + + @parameterized.named_parameters( + dict( + testcase_name='invalid_field_name_multi_hand_landmark', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmark', 'multi_hand_world_landmarks', + 'multi_handedness' + ])(1, 2, 3)), + dict( + testcase_name='invalid_field_name_multi_hand_world_landmarks', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmarks', 'multi_hand_world_landmark', + 'multi_handedness' + ])(1, 2, 3)), + dict( + testcase_name='invalid_field_name_multi_handed', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmarks', 'multi_hand_world_landmarks', + 'multi_handed' + ])(1, 2, 3)), + dict( + testcase_name='multi_hand_landmarks_is_none', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmarks', 'multi_hand_world_landmarks', + 'multi_handedness' + ])(None, 2, 3)), + dict( + testcase_name='multi_hand_world_landmarks_is_none', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmarks', 'multi_hand_world_landmarks', + 'multi_handedness' + ])(1, None, 3)), + dict( + testcase_name='multi_handedness_is_none', + hand=collections.namedtuple('Hand', [ + 'multi_hand_landmarks', 'multi_hand_world_landmarks', + 'multi_handedness' + ])(1, 2, None)), + ) + def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): + with unittest.mock.patch.object( + mp_hands.Hands, 'process', return_value=hand): + input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) + with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): + dataset.Dataset.from_folder( + dirname=input_data_dir, + hparams=dataset.HandDataPreprocessingParams()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py new file mode 100644 index 000000000..f297d8640 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -0,0 +1,239 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""APIs to train gesture recognizer model.""" + +import os +from typing import List + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import loss_functions +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.vision.gesture_recognizer import constants +from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer + +_EMBEDDING_SIZE = 128 + + +class GestureRecognizer(classifier.Classifier): + """GestureRecognizer for building hand gesture recognizer model. + + Attributes: + embedding_size: Size of the input gesture embedding vector. + """ + + def __init__(self, label_names: List[str], + model_options: model_opt.GestureRecognizerModelOptions, + hparams: hp.HParams): + """Initializes GestureRecognizer class. + + Args: + label_names: A list of label names for the classes. + model_options: options to create gesture recognizer model. + hparams: The hyperparameters for training hand gesture recognizer model. + """ + super().__init__( + model_spec=None, label_names=label_names, shuffle=hparams.shuffle) + self._model_options = model_options + self._hparams = hparams + self._history = None + self.embedding_size = _EMBEDDING_SIZE + + @classmethod + def create( + cls, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + options: gesture_recognizer_options.GestureRecognizerOptions, + ) -> 'GestureRecognizer': + """Creates and trains a hand gesture recognizer with input datasets. + + If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/ + directory, the training process will load the weight from the checkpoint + file for continual training. + + Args: + train_data: Training data. + validation_data: Validation data. If None, skips validation process. + options: options for creating and training gesture recognizer model. + + Returns: + An instance of GestureRecognizer. 
+ """ + if options.model_options is None: + options.model_options = model_opt.GestureRecognizerModelOptions() + + if options.hparams is None: + options.hparams = hp.HParams() + + gesture_recognizer = cls( + label_names=train_data.label_names, + model_options=options.model_options, + hparams=options.hparams) + + gesture_recognizer._create_model() + + train_dataset = train_data.gen_tf_dataset( + batch_size=options.hparams.batch_size, + is_training=True, + shuffle=options.hparams.shuffle) + options.hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=options.hparams.steps_per_epoch, + batch_size=options.hparams.batch_size, + train_data=train_data) + train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch) + + validation_dataset = validation_data.gen_tf_dataset( + batch_size=options.hparams.batch_size, is_training=False) + + tf.compat.v1.logging.info('Training the gesture recognizer model...') + gesture_recognizer._train( + train_data=train_dataset, validation_data=validation_dataset) + + return gesture_recognizer + + def _train(self, train_data: tf.data.Dataset, + validation_data: tf.data.Dataset): + """Trains the model with input train_data. + + The training results are recorded by a self.History object returned by + tf.keras.Model.fit(). + + Args: + train_data: Training data. + validation_data: Validation data. + """ + hparams = self._hparams + + scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch) + scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler) + + job_dir = hparams.export_dir + checkpoint_path = os.path.join(job_dir, 'epoch_models') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True) + + best_model_path = os.path.join(job_dir, 'best_model_weights') + best_model_callback = tf.keras.callbacks.ModelCheckpoint( + best_model_path, + monitor='val_loss', + mode='min', + save_best_only=True, + save_weights_only=True) + + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=os.path.join(job_dir, 'logs')) + + self._model.compile( + optimizer='adam', + loss=loss_functions.FocalLoss(gamma=self._hparams.gamma), + metrics=['categorical_accuracy']) + + latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) + if latest_checkpoint: + print(f'Resuming from {latest_checkpoint}') + self._model.load_weights(latest_checkpoint) + + self._history = self._model.fit( + x=train_data, + epochs=hparams.epochs, + validation_data=validation_data, + validation_freq=1, + callbacks=[ + checkpoint_callback, best_model_callback, scheduler_callback, + tensorboard_callback + ], + ) + + def _create_model(self): + """Creates the hand gesture recognizer model. + + The gesture embedding model is pretrained and loaded from a tf.saved_model. + """ + inputs = tf.keras.Input( + shape=[self.embedding_size], + batch_size=None, + dtype=tf.float32, + name='hand_embedding') + + x = tf.keras.layers.BatchNormalization()(inputs) + x = tf.keras.layers.ReLU()(x) + dropout_rate = self._model_options.dropout_rate + x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x) + outputs = tf.keras.layers.Dense( + self._num_classes, + activation='softmax', + name='custom_gesture_recognizer')( + x) + + self._model = tf.keras.Model(inputs=inputs, outputs=outputs) + + print(self._model.summary()) + + def export_model(self, model_name: str = 'gesture_recognizer.task'): + """Converts the model to TFLite and exports as a model bundle file. 
+ + Saves a model bundle file and metadata json file to hparams.export_dir. The + resulting model bundle file will contain necessary models for hand + detection, canned gesture classification, and customized gesture + classification. Only the model bundle file is needed for the downstream + gesture recognition task. The metadata.json file is saved only to + interpret the contents of the model bundle file. + + The customized gesture model is in float without quantization. The model is + lightweight and there is no need to balance performance and efficiency by + quantization. The default score_thresholding is set to 0.5 as it can be + adjusted during inference. + + Args: + model_name: File name to save model bundle file. The full export path is + {export_dir}/{model_name}. + """ + # TODO: Convert keras embedder model instead of using tflite + gesture_embedding_model_buffer = model_util.load_tflite_model_buffer( + constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE) + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + canned_gesture_model_buffer = model_util.load_tflite_model_buffer( + constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE) + + if not tf.io.gfile.exists(self._hparams.export_dir): + tf.io.gfile.makedirs(self._hparams.export_dir) + model_bundle_file = os.path.join(self._hparams.export_dir, model_name) + metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json') + + gesture_classifier_options = metadata_writer.GestureClassifierOptions( + model_buffer=model_util.convert_to_tflite(self._model), + labels=base_metadata_writer.Labels().add(list(self._label_names)), + score_thresholding=base_metadata_writer.ScoreThresholding( + global_score_threshold=0.5)) + + writer = metadata_writer.MetadataWriter.create( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer, + gesture_embedding_model_buffer, canned_gesture_model_buffer, + gesture_classifier_options) + model_bundle_content, metadata_json = writer.populate() + with open(model_bundle_file, 'wb') as f: + f.write(model_bundle_content) + with open(metadata_file, 'w') as f: + f.write(metadata_json) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py new file mode 100644 index 000000000..06075fbc6 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -0,0 +1,78 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Demo for making an gesture recognizer model by Mediapipe Model Maker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from absl import app +from absl import flags +from absl import logging + +from mediapipe.model_maker.python.vision import gesture_recognizer + +FLAGS = flags.FLAGS + +# TODO: Move hand gesture recognizer demo dataset to an +# open-sourced directory. +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' + + +def define_flags(): + flags.DEFINE_string('export_dir', None, + 'The directory to save exported files.') + flags.DEFINE_string('input_data_dir', None, + 'The directory with input training data.') + flags.mark_flag_as_required('export_dir') + + +def run(data_dir: str, export_dir: str): + """Runs demo.""" + data = gesture_recognizer.Dataset.from_folder(dirname=data_dir) + train_data, rest_data = data.split(0.8) + validation_data, test_data = rest_data.split(0.5) + + model = gesture_recognizer.GestureRecognizer.create( + train_data=train_data, + validation_data=validation_data, + options=gesture_recognizer.GestureRecognizerOptions( + hparams=gesture_recognizer.HParams(export_dir=export_dir))) + + metric = model.evaluate(test_data, batch_size=2) + print('Evaluation metric') + print(metric) + + model.export_model() + + +def main(_): + logging.set_verbosity(logging.INFO) + + if FLAGS.input_data_dir is None: + data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR) + else: + data_dir = FLAGS.input_data_dir + + export_dir = os.path.expanduser(FLAGS.export_dir) + run(data_dir=data_dir, export_dir=export_dir) + + +if __name__ == '__main__': + define_flags() + app.run(main) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py new file mode 100644 index 000000000..da9e2d647 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py @@ -0,0 +1,32 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Options for building gesture recognizer.""" + +import dataclasses +from typing import Optional + +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt + + +@dataclasses.dataclass +class GestureRecognizerOptions: + """Configurable options for building gesture recognizer. + + Attributes: + model_options: A set of options for configuring the selected model. + hparams: A set of hyperparameters used to train the gesture recognizer. 
+ """ + model_options: Optional[model_opt.GestureRecognizerModelOptions] = None + hparams: Optional[hyperparameters.HParams] = None diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py new file mode 100644 index 000000000..eb2b1d171 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -0,0 +1,132 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +from unittest import mock as unittest_mock +import zipfile + +import mock +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import test_util +from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.tasks.python.test import test_utils + +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' + + +class GestureRecognizerTest(tf.test.TestCase): + + def _load_data(self): + input_data_dir = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, 'raw_data')) + + data = gesture_recognizer.Dataset.from_folder( + dirname=input_data_dir, + hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True)) + return data + + def setUp(self): + super().setUp() + self._model_options = gesture_recognizer.ModelOptions() + self._hparams = gesture_recognizer.HParams(epochs=2) + self._gesture_recognizer_options = ( + gesture_recognizer.GestureRecognizerOptions( + model_options=self._model_options, hparams=self._hparams)) + all_data = self._load_data() + # Splits data, 90% data for training, 10% for testing + self._train_data, self._test_data = all_data.split(0.9) + + def test_gesture_recognizer_model(self): + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._test_data, + options=self._gesture_recognizer_options) + + self._test_accuracy(model) + + def test_export_gesture_recognizer_model(self): + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._test_data, + options=self._gesture_recognizer_options) + model.export_model() + model_bundle_file = os.path.join(self._hparams.export_dir, + 'gesture_recognizer.task') + with zipfile.ZipFile(model_bundle_file) as zf: + self.assertEqual( + set(zf.namelist()), + set(['hand_landmarker.task', 'hand_gesture_recognizer.task'])) + zf.extractall(self.get_temp_dir()) + hand_gesture_recognizer_bundle_file = os.path.join( + self.get_temp_dir(), 'hand_gesture_recognizer.task') + with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf: + self.assertEqual( + set(zf.namelist()), + set([ + 'canned_gesture_classifier.tflite', + 'custom_gesture_classifier.tflite', 'gesture_embedder.tflite' + ])) + zf.extractall(self.get_temp_dir()) + gesture_classifier_tflite_file = os.path.join( + self.get_temp_dir(), 'custom_gesture_classifier.tflite') + test_util.test_tflite_file( + 
keras_model=model._model, + tflite_file=gesture_classifier_tflite_file, + size=[1, model.embedding_size]) + + def _test_accuracy(self, model, threshold=0.5): + _, accuracy = model.evaluate(self._test_data) + tf.compat.v1.logging.info(f'accuracy: {accuracy}') + self.assertGreaterEqual(accuracy, threshold) + + @unittest_mock.patch.object( + gesture_recognizer.hyperparameters, + 'HParams', + autospec=True, + return_value=gesture_recognizer.HParams(epochs=1)) + @unittest_mock.patch.object( + gesture_recognizer.model_options, + 'GestureRecognizerModelOptions', + autospec=True, + return_value=gesture_recognizer.ModelOptions()) + def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + self, mock_hparams, mock_model_options): + options = gesture_recognizer.GestureRecognizerOptions() + gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._test_data, + options=options) + mock_hparams.assert_called_once() + mock_model_options.assert_called_once() + + def test_continual_training_by_loading_checkpoint(self): + mock_stdout = io.StringIO() + with mock.patch('sys.stdout', mock_stdout): + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._test_data, + options=self._gesture_recognizer_options) + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._test_data, + options=self._gesture_recognizer_options) + self._test_accuracy(model) + + self.assertRegex(mock_stdout.getvalue(), 'Resuming from') + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py b/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py new file mode 100644 index 000000000..fed62453b --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training customized gesture recognizer models.""" + +import dataclasses + +from mediapipe.model_maker.python.core import hyperparameters as hp + + +@dataclasses.dataclass +class HParams(hp.BaseHParams): + """The hyperparameters for training gesture recognizer. + + Attributes: + learning_rate: Learning rate to use for gradient descent training. + batch_size: Batch size for training. + epochs: Number of training iterations over the dataset. + lr_decay: Learning rate decay to use for gradient descent training. + gamma: Gamma parameter for focal loss. + """ + # Parameters from BaseHParams class. + learning_rate: float = 0.001 + batch_size: int = 2 + epochs: int = 10 + + # Parameters about training configuration + # TODO: Move lr_decay to hp.baseHParams. 
+ lr_decay: float = 0.99 + gamma: int = 2 diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py new file mode 100644 index 000000000..58b67e072 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -0,0 +1,193 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and creates model asset bundle for gesture recognizer.""" + +import dataclasses +import os +import tempfile +from typing import Union + +import tensorflow as tf +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.metadata.metadata_writers import writer_utils + +_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite" +_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite" +_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task" +_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task" +_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite" +_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite" +_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite" + +_MODEL_NAME = "HandGestureRecognition" +_MODEL_DESCRIPTION = "Recognize the hand gesture in the image." + +_INPUT_NAME = "embedding" +_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder." +_OUTPUT_NAME = "scores" +_OUTPUT_DESCRIPTION = "Hand gesture category scores." + + +@dataclasses.dataclass +class GestureClassifierOptions: + """Options to write metadata for gesture classifier. + + Attributes: + model_buffer: Gesture classifier TFLite model buffer. + labels: Labels for the gesture classifier. + score_thresholding: Parameters to performs thresholding on output tensor + values [1]. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468 + """ + model_buffer: bytearray + labels: metadata_writer.Labels + score_thresholding: metadata_writer.ScoreThresholding + + +def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: + with tf.io.gfile.GFile(file_path, mode) as f: + return f.read() + + +class MetadataWriter: + """MetadataWriter to write the metadata and the model asset bundle.""" + + def __init__( + self, hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + gesture_embedder_model_buffer: bytearray, + canned_gesture_classifier_model_buffer: bytearray, + custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter + ) -> None: + """Initialize MetadataWriter to write the metadata and model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. 
+ hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded + from the TFLite gesture embedder model file. + canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite canned gesture classifier model file. + custom_gesture_classifier_metadata_writer: Metadata writer to write custom + gesture classifier metadata into the TFLite file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._gesture_embedder_model_buffer = gesture_embedder_model_buffer + self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer + self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + @classmethod + def create( + cls, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + gesture_embedder_model_buffer: bytearray, + canned_gesture_classifier_model_buffer: bytearray, + custom_gesture_classifier_options: GestureClassifierOptions, + ) -> "MetadataWriter": + """Creates MetadataWriter to write the metadata for gesture recognizer. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded + from the TFLite gesture embedder model file. + canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite canned gesture classifier model file. + custom_gesture_classifier_options: Custom gesture classifier options to + write custom gesture classifier metadata into the TFLite file. + + Returns: + An MetadataWrite object. + """ + writer = metadata_writer.MetadataWriter.create( + custom_gesture_classifier_options.model_buffer) + writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION) + writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION) + writer.add_classification_output( + labels=custom_gesture_classifier_options.labels, + score_thresholding=custom_gesture_classifier_options.score_thresholding, + name=_OUTPUT_NAME, + description=_OUTPUT_DESCRIPTION) + return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer, + gesture_embedder_model_buffer, + canned_gesture_classifier_model_buffer, writer) + + def populate(self): + """Populates the metadata and creates model asset bundle. + + Note that only the output model asset bundle is used for deployment. + The output JSON content is used to interpret the custom gesture classifier + metadata content. + + Returns: + A tuple of (model_asset_bundle_in_bytes, metadata_json_content) + """ + # Creates the model asset bundle for hand landmarker task. 
+ landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + + # Write metadata into custom gesture classifier model. + self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( + ) + # Creates the model asset bundle for hand gesture recognizer sub graph. + hand_gesture_recognizer_models = { + _GESTURE_EMBEDDER_TFLITE_NAME: + self._gesture_embedder_model_buffer, + _CANNED_GESTURE_CLASSIFIER_TFLITE_NAME: + self._canned_gesture_classifier_model_buffer, + _CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME: + self._custom_gesture_classifier_model_buffer + } + output_hand_gesture_recognizer_path = os.path.join( + self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models, + output_hand_gesture_recognizer_path) + + # Creates the model asset bundle for end-to-end hand gesture recognizer + # graph. + gesture_recognizer_models = { + _HAND_LANDMARKER_BUNDLE_NAME: + read_file(output_hand_landmarker_path), + _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: + read_file(output_hand_gesture_recognizer_path), + } + + output_file_path = os.path.join(self._temp_folder.name, + "gesture_recognizer.task") + writer_utils.create_model_asset_bundle(gesture_recognizer_models, + output_file_path) + with open(output_file_path, "rb") as f: + gesture_recognizer_model_buffer = f.read() + return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py new file mode 100644 index 000000000..e1101e066 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -0,0 +1,90 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for metadata_writer.""" + +import os +import zipfile + +import tensorflow as tf + +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer +from mediapipe.tasks.python.test import test_utils + +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" + +_EXPECTED_JSON = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) +_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite")) + + +class MetadataWriterTest(tf.test.TestCase): + + def test_write_metadata_and_create_model_asset_bundle_successful(self): + # Use dummy model buffer for unit test only. + hand_detector_model_buffer = b"\x11\x12" + hand_landmarks_detector_model_buffer = b"\x22" + gesture_embedder_model_buffer = b"\x33" + canned_gesture_classifier_model_buffer = b"\x44" + custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions( + model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH), + labels=base_metadata_writer.Labels().add( + ["None", "Paper", "Rock", "Scissors"]), + score_thresholding=base_metadata_writer.ScoreThresholding( + global_score_threshold=0.5)) + writer = metadata_writer.MetadataWriter.create( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer, + gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer, + custom_gesture_classifier_metadata_writer) + model_bundle_content, metadata_json = writer.populate() + with open(_EXPECTED_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + # Checks the top-level model bundle can be extracted successfully. + model_bundle_filepath = os.path.join(self.get_temp_dir(), + "gesture_recognition.task") + + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarker.task", "hand_gesture_recognizer.task"])) + zf.extractall(self.get_temp_dir()) + + # Checks the model bundles for sub-task can be extracted successfully. + hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(), + "hand_landmarker.task") + with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarks_detector.tflite", "hand_detector.tflite"])) + + hand_gesture_recognizer_bundle_filepath = os.path.join( + self.get_temp_dir(), "hand_gesture_recognizer.task") + with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set([ + "canned_gesture_classifier.tflite", + "custom_gesture_classifier.tflite", "gesture_embedder.tflite" + ])) + + +if __name__ == "__main__": + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py new file mode 100644 index 000000000..79a84c792 --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configurable model options for gesture recognizer models.""" + +import dataclasses + + +@dataclasses.dataclass +class GestureRecognizerModelOptions: + """Configurable options for gesture recognizer model. + + Attributes: + dropout_rate: The fraction of the input units to drop, used in dropout + layer. + """ + dropout_rate: float = 0.05 diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite new file mode 100644 index 000000000..553f9b402 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json new file mode 100644 index 000000000..58739061d --- /dev/null +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json @@ -0,0 +1,56 @@ +{ + "name": "HandGestureRecognition", + "description": "Recognize the hand gesture in the image.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "embedding", + "description": "Embedding feature vector from gesture embedder.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "scores", + "description": "Hand gesture category scores.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreThresholdingOptions", + "options": { + "global_score_threshold": 0.5 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg new file mode 100644 index 000000000..ad3eafea0 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg new file mode 100644 index 000000000..a8e443080 Binary files /dev/null and 
b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg new file mode 100644 index 000000000..18fd8128f Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg new file mode 100644 index 000000000..d36347e43 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg new file mode 100644 index 000000000..26866e4ad Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg new file mode 100644 index 000000000..293f282be Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg new file mode 100644 index 000000000..12a0d8fb5 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg new file mode 100644 index 000000000..9917c8690 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg new file mode 100644 index 000000000..62a81e19b Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg 
b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg new file mode 100644 index 000000000..7e7af1d7e Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg new file mode 100644 index 000000000..66b41462b Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg new file mode 100644 index 000000000..4746b32da Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg new file mode 100644 index 000000000..2eee14bd0 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg new file mode 100644 index 000000000..4b8259e4f Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg new file mode 100644 index 000000000..46db0a2f6 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg new file mode 100644 index 000000000..4c294e1bf Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg new file mode 100644 index 000000000..afe6b91e5 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg differ diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg new file mode 100644 index 000000000..db34a18c5 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg new file mode 100644 index 000000000..16df26ac0 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg new file mode 100644 index 000000000..0d19f69c4 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg new file mode 100644 index 000000000..7a33cbf89 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg new file mode 100644 index 000000000..a35d80cc7 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg new file mode 100644 index 000000000..e9369aa3c Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg new file mode 100644 index 000000000..b1353372b Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg new file mode 100644 index 000000000..f9d12aa64 Binary files /dev/null and 
b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg new file mode 100644 index 000000000..890cde282 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg new file mode 100644 index 000000000..89caba678 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg new file mode 100644 index 000000000..e9a98c988 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg new file mode 100644 index 000000000..ae297e6eb Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg new file mode 100644 index 000000000..240e7315f Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg new file mode 100644 index 000000000..17e7b9d4c Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg new file mode 100644 index 000000000..b2e4133f8 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg 
b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg new file mode 100644 index 000000000..70cac9317 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg new file mode 100644 index 000000000..8fe840be7 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg new file mode 100644 index 000000000..eb7915ec1 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg new file mode 100644 index 000000000..55a470cec Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg new file mode 100644 index 000000000..89a9a9103 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg new file mode 100644 index 000000000..57b8bf3b1 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg new file mode 100644 index 000000000..1ebc886f1 Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg new file mode 100644 index 000000000..7c56a32cd Binary files /dev/null and b/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg differ diff --git 
a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 8548a60d8..debe5404a 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -120,6 +120,34 @@ py_library( ], ) +py_library( + name = "solution_base", + srcs = ["solution_base.py"], + srcs_version = "PY3", + visibility = [ + "//mediapipe/python:__subpackages__", + ], + deps = [ + ":_framework_bindings", + ":packet_creator", + ":packet_getter", + "//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2", + "//mediapipe/calculators/image:image_transformation_calculator_py_pb2", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2", + "//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2", + "//mediapipe/calculators/util:logic_calculator_py_pb2", + "//mediapipe/calculators/util:thresholding_calculator_py_pb2", + "//mediapipe/framework:calculator_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", + "//mediapipe/framework/formats:detection_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/framework/formats:rect_py_pb2", + "//mediapipe/modules/objectron/calculators:annotation_py_pb2", + "//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2", + "@com_google_protobuf//:protobuf_python", + ], +) + py_test( name = "calculator_graph_test", srcs = ["calculator_graph_test.py"], @@ -175,3 +203,24 @@ py_test( ":_framework_bindings", ], ) + +py_test( + name = "solution_base_test", + srcs = ["solution_base_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":solution_base", + "//file/google_src", + "//file/localfile", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/core:side_packet_to_stream_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/util:detection_unique_id_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_py_pb2", + "//mediapipe/framework/formats:detection_py_pb2", + "@com_google_protobuf//:protobuf_python", + ], +) diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index a482e6f6a..020d16371 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2 from mediapipe.calculators.util import logic_calculator_pb2 from mediapipe.calculators.util import thresholding_calculator_pb2 from mediapipe.framework import calculator_pb2 -from mediapipe.framework.formats import body_rig_pb2 from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import landmark_pb2 diff --git a/mediapipe/python/solutions/BUILD b/mediapipe/python/solutions/BUILD new file mode 100644 index 000000000..2ca0ba9ff --- /dev/null +++ b/mediapipe/python/solutions/BUILD @@ -0,0 +1,105 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
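The mediapipe/python/solutions/BUILD file continuing below packages the legacy Python solution API (the hands solution plus drawing utilities and their tests). For background, typical standalone usage of the hands solution looks like the following sketch; this is the standard MediaPipe Solutions API, shown only as context, and the image path is a placeholder.

```python
import cv2
import mediapipe as mp

# Standard usage of the hands solution packaged by the py_library below.
with mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=2) as hands:
  bgr = cv2.imread("hands.jpg")  # placeholder path
  results = hands.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
  if results.multi_hand_landmarks:
    print(f"detected {len(results.multi_hand_landmarks)} hand(s)")
```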
+# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "hands", + srcs = [ + "hands.py", + "hands_connections.py", + ], + data = [ + "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite", + "//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite", + "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph", + "//mediapipe/modules/hand_landmark:handedness.txt", + "//mediapipe/modules/palm_detection:palm_detection_full.tflite", + "//mediapipe/modules/palm_detection:palm_detection_lite.tflite", + ], + srcs_version = "PY3", + deps = [ + "//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2", + "//mediapipe/calculators/core:gate_calculator_py_pb2", + "//mediapipe/calculators/core:split_vector_calculator_py_pb2", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2", + "//mediapipe/calculators/tensor:inference_calculator_py_pb2", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2", + "//mediapipe/calculators/util:association_calculator_py_pb2", + "//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2", + "//mediapipe/calculators/util:logic_calculator_py_pb2", + "//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2", + "//mediapipe/calculators/util:rect_transformation_calculator_py_pb2", + "//mediapipe/calculators/util:thresholding_calculator_py_pb2", + "//mediapipe/python:solution_base", + ], +) + +py_library( + name = "drawing_styles", + srcs = ["drawing_styles.py"], + srcs_version = "PY3", + deps = [ + "drawing_utils", + "face_mesh", + "hands", + "pose", + ], +) + +py_library( + name = "drawing_utils", + srcs = ["drawing_utils.py"], + srcs_version = "PY3", + deps = [ + "//mediapipe/framework/formats:detection_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/framework/formats:location_data_py_pb2", + ], +) + +py_test( + name = "drawing_utils_test", + srcs = ["drawing_utils_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":drawing_utils", + "//mediapipe/framework/formats:detection_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", + "@com_google_protobuf//:protobuf_python", + ], +) + +py_test( + name = "hands_test", + srcs = ["hands_test.py"], + data = [ + ":testdata/asl_hand.25fps.mp4", + ":testdata/asl_hand.full.npz", + ":testdata/hands.jpg", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":drawing_styles", + ":drawing_utils", + ":hands", + ], +) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index a89a192ee..b4ec3b36c 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -88,6 +88,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"], ) + http_file( + name = "com_google_mediapipe_canned_gesture_classifier_tflite", + sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2", + urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"], + ) + http_file( name = "com_google_mediapipe_cat_jpg", sha256 = 
"2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be", @@ -286,6 +292,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/general_meta.json?generation=1665422822603848"], ) + http_file( + name = "com_google_mediapipe_gesture_embedder_tflite", + sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"], + ) + http_file( name = "com_google_mediapipe_gesture_recognizer_task", sha256 = "a966b1d4e774e0423c19c8aa71f070e5a72fe7a03c2663dd2f3cb0b0095ee3e1", @@ -721,7 +733,7 @@ def external_files(): http_file( name = "com_google_mediapipe_README_md", sha256 = "a96d08c9c70cd9717207ed72c926e02e5eada751f00bdc5d3a7e82e3492b72cb", - urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"], + urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1668124179156767"], ) http_file( @@ -970,6 +982,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"], ) + http_file( + name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", + sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"], + ) + + http_file( + name = "com_google_mediapipe_gesture_embedder_saved_model_pb", + sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"], + ) + http_file( name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001", sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97", @@ -1005,3 +1029,15 @@ def external_files(): sha256 = "f29606cf218397d5580c496e50fd28cddf66e2f59b819ab9c761b72270a5adf3", urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"], ) + + http_file( + name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", + sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"], + ) + + http_file( + name = "com_google_mediapipe_gesture_embedder_variables_variables_index", + sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"], + )