Update gesture recognizer to new mediapipe tasks pipeline

PiperOrigin-RevId: 493950564
MediaPipe Team, 2022-12-08 11:30:39 -08:00, committed by Copybara-Service
parent 13f8fa5139
commit a641ea12e1
7 changed files with 147 additions and 87 deletions

mediapipe/model_maker/python/vision/gesture_recognizer/BUILD

@@ -35,20 +35,21 @@ py_library(
     srcs = ["constants.py"],
 )
 
-# TODO: Change to py_library after migrating the MediaPipe hand solution
-# library to MediaPipe hand task library.
 py_library(
     name = "dataset",
     srcs = ["dataset.py"],
     deps = [
         ":constants",
+        ":metadata_writer",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
-        "//mediapipe/model_maker/python/core/data:data_util",
         "//mediapipe/model_maker/python/core/utils:model_util",
-        "//mediapipe/python/solutions:hands",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/vision:hand_landmarker",
     ],
 )
 
+# TODO: Remove notsan tag once tasks no longer has race condition issue
 py_test(
     name = "dataset_test",
     srcs = ["dataset_test.py"],
@@ -56,10 +57,11 @@ py_test(
         ":testdata",
         "//mediapipe/model_maker/models/gesture_recognizer:models",
     ],
+    tags = ["notsan"],
     deps = [
         ":dataset",
-        "//mediapipe/python/solutions:hands",
         "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/vision:hand_landmarker",
     ],
 )
@@ -131,6 +133,7 @@ py_library(
     ],
 )
 
+# TODO: Remove notsan tag once tasks no longer has race condition issue
 py_test(
     name = "gesture_recognizer_test",
     size = "large",
@@ -140,6 +143,7 @@ py_test(
         "//mediapipe/model_maker/models/gesture_recognizer:models",
     ],
     shard_count = 2,
+    tags = ["notsan"],
     deps = [
         ":gesture_recognizer_import",
         "//mediapipe/model_maker/python/core/utils:test_util",

mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py

@@ -16,16 +16,22 @@
 import dataclasses
 import os
 import random
-from typing import List, NamedTuple, Optional
+from typing import List, Optional
 
-import cv2
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
-from mediapipe.model_maker.python.core.data import data_util
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.vision.gesture_recognizer import constants
-from mediapipe.python.solutions import hands as mp_hands
+from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
+
+_Image = image_module.Image
+_HandLandmarker = hand_landmarker_module.HandLandmarker
+_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
+_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
 
 @dataclasses.dataclass
@@ -59,7 +65,7 @@ class HandData:
   handedness: List[float]
 
 
-def _validate_data_sample(data: NamedTuple) -> bool:
+def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
   """Validates the input hand data sample.
 
   Args:
@@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
     'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
     or any of these attributes' values are none. Otherwise, True.
   """
-  if (not hasattr(data, 'multi_hand_landmarks') or
-      data.multi_hand_landmarks is None):
+  if data.hand_landmarks is None or not data.hand_landmarks:
     return False
-  if (not hasattr(data, 'multi_hand_world_landmarks') or
-      data.multi_hand_world_landmarks is None):
+  if data.hand_world_landmarks is None or not data.hand_world_landmarks:
     return False
-  if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
+  if data.handedness is None or not data.handedness:
     return False
   return True
 
 
 def _get_hand_data(all_image_paths: List[str],
-                   min_detection_confidence: float) -> Optional[HandData]:
+                   min_detection_confidence: float) -> List[Optional[HandData]]:
   """Computes hand data (landmarks and handedness) in the input image.
 
   Args:
@@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
     A HandData object. Returns None if no hand is detected.
   """
   hand_data_result = []
-  with mp_hands.Hands(
-      static_image_mode=True,
-      max_num_hands=1,
-      min_detection_confidence=min_detection_confidence) as hands:
+  hand_detector_model_buffer = model_util.load_tflite_model_buffer(
+      constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
+  hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
+      constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
+  hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
+      hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
+  hand_landmarker_options = _HandLandmarkerOptions(
+      base_options=base_options_module.BaseOptions(
+          model_asset_buffer=hand_landmarker_writer.populate()),
+      num_hands=1,
+      min_hand_detection_confidence=min_detection_confidence,
+      min_hand_presence_confidence=0.5,
+      min_tracking_confidence=1,
+  )
+  with _HandLandmarker.create_from_options(
+      hand_landmarker_options) as hand_landmarker:
     for path in all_image_paths:
       tf.compat.v1.logging.info('Loading image %s', path)
-      image = data_util.load_image(path)
-      # Flip image around y-axis for correct handedness output
-      image = cv2.flip(image, 1)
-      data = hands.process(image)
+      image = _Image.create_from_file(path)
+      data = hand_landmarker.detect(image)
       if not _validate_data_sample(data):
         hand_data_result.append(None)
         continue
-      hand_landmarks = [[
-          hand_landmark.x, hand_landmark.y, hand_landmark.z
-      ] for hand_landmark in data.multi_hand_landmarks[0].landmark]
+      hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
+                        for hand_landmark in data.hand_landmarks[0]]
       hand_world_landmarks = [[
           hand_landmark.x, hand_landmark.y, hand_landmark.z
-      ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
+      ] for hand_landmark in data.hand_world_landmarks[0]]
       handedness_scores = [
-          handedness.score
-          for handedness in data.multi_handedness[0].classification
+          handedness.score for handedness in data.handedness[0]
       ]
       hand_data_result.append(
          HandData(
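Note: the migrated loop above follows the standard MediaPipe Tasks invocation pattern. A minimal standalone sketch of that pattern, assuming a prebuilt hand_landmarker.task bundle and a local test image (both hypothetical; the dataset code above instead assembles the bundle in memory via HandLandmarkerMetadataWriter):

    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module

    # Hypothetical on-disk bundle; the code in this commit builds the buffer itself.
    options = hand_landmarker_module.HandLandmarkerOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='hand_landmarker.task'),
        num_hands=1)
    with hand_landmarker_module.HandLandmarker.create_from_options(
        options) as landmarker:
      image = image_module.Image.create_from_file('hand.jpg')  # hypothetical image
      result = landmarker.detect(image)
      if result.hand_landmarks:  # empty list when no hand is detected
        print([(lm.x, lm.y, lm.z) for lm in result.hand_landmarks[0]])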

mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py

@@ -12,21 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import os
 import shutil
 from typing import NamedTuple
 import unittest
 
-from absl import flags
 from absl.testing import parameterized
 import tensorflow as tf
 
 from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
-from mediapipe.python.solutions import hands as mp_hands
 from mediapipe.tasks.python.test import test_utils
+from mediapipe.tasks.python.vision import hand_landmarker
 
-FLAGS = flags.FLAGS
-
 _TEST_DATA_DIRNAME = 'raw_data'
@@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
         dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
     train_data, test_data = data.split(0.5)
 
-    self.assertLen(train_data, 17)
+    self.assertLen(train_data, 16)
     for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
       self.assertEqual(elem[0].shape, (1, 128))
       self.assertEqual(elem[1].shape, ([1, 4]))
     self.assertEqual(train_data.num_classes, 4)
     self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
 
-    self.assertLen(test_data, 18)
+    self.assertLen(test_data, 16)
     for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
       self.assertEqual(elem[0].shape, (1, 128))
       self.assertEqual(elem[1].shape, ([1, 4]))
@@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
     for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
       self.assertEqual(elem[0].shape, (1, 128))
       self.assertEqual(elem[1].shape, ([1, 4]))
-    self.assertLen(data, 35)
+    self.assertLen(data, 32)
     self.assertEqual(data.num_classes, 4)
     self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
@@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
     for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
       self.assertEqual(elem[0].shape, (1, 128))
       self.assertEqual(elem[1].shape, ([1, 4]))
-    self.assertLen(data, 35)
+    self.assertLen(data, 32)
     self.assertEqual(data.num_classes, 4)
     self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
 
   @parameterized.named_parameters(
       dict(
-          testcase_name='invalid_field_name_multi_hand_landmark',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmark', 'multi_hand_world_landmarks',
-              'multi_handedness'
-          ])(1, 2, 3)),
+          testcase_name='none_handedness',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=None, hand_landmarks=[[2]],
+              hand_world_landmarks=[[3]])),
       dict(
-          testcase_name='invalid_field_name_multi_hand_world_landmarks',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmarks', 'multi_hand_world_landmark',
-              'multi_handedness'
-          ])(1, 2, 3)),
+          testcase_name='none_hand_landmarks',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=[[1]], hand_landmarks=None,
+              hand_world_landmarks=[[3]])),
       dict(
-          testcase_name='invalid_field_name_multi_handed',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmarks', 'multi_hand_world_landmarks',
-              'multi_handed'
-          ])(1, 2, 3)),
+          testcase_name='none_hand_world_landmarks',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=[[1]], hand_landmarks=[[2]],
+              hand_world_landmarks=None)),
       dict(
-          testcase_name='multi_hand_landmarks_is_none',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmarks', 'multi_hand_world_landmarks',
-              'multi_handedness'
-          ])(None, 2, 3)),
+          testcase_name='empty_handedness',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
       dict(
-          testcase_name='multi_hand_world_landmarks_is_none',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmarks', 'multi_hand_world_landmarks',
-              'multi_handedness'
-          ])(1, None, 3)),
+          testcase_name='empty_hand_landmarks',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
       dict(
-          testcase_name='multi_handedness_is_none',
-          hand=collections.namedtuple('Hand', [
-              'multi_hand_landmarks', 'multi_hand_world_landmarks',
-              'multi_handedness'
-          ])(1, 2, None)),
+          testcase_name='empty_hand_world_landmarks',
+          hand=hand_landmarker.HandLandmarkerResult(
+              handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
   )
   def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
     with unittest.mock.patch.object(
-        mp_hands.Hands, 'process', return_value=hand):
+        hand_landmarker.HandLandmarker, 'detect', return_value=hand):
       input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
       with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
         dataset.Dataset.from_folder(
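Note: every invalid-data case above goes through the same seam: HandLandmarker.detect is patched to return a canned HandLandmarkerResult. A condensed sketch of that pattern, using only names that appear in this diff:

    import unittest.mock

    from mediapipe.tasks.python.vision import hand_landmarker

    # A result with no detected hands: every field is an empty list.
    empty_result = hand_landmarker.HandLandmarkerResult(
        handedness=[], hand_landmarks=[], hand_world_landmarks=[])

    # Inside this block, any HandLandmarker.detect call returns the canned
    # result, so dataset creation finds no valid hand in any image.
    with unittest.mock.patch.object(
        hand_landmarker.HandLandmarker, 'detect', return_value=empty_result):
      pass  # e.g. dataset.Dataset.from_folder(...) raises ValueError here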

mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py

@@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
     return f.read()
 
 
+class HandLandmarkerMetadataWriter:
+  """MetadataWriter to write the model asset bundle for HandLandmarker."""
+
+  def __init__(
+      self,
+      hand_detector_model_buffer: bytearray,
+      hand_landmarks_detector_model_buffer: bytearray,
+  ) -> None:
+    """Initializes HandLandmarkerMetadataWriter to write model asset bundle.
+
+    Args:
+      hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
+        the TFLite hand detector model file.
+      hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
+        loaded from the TFLite hand landmarks detector model file.
+    """
+    self._hand_detector_model_buffer = hand_detector_model_buffer
+    self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
+    self._temp_folder = tempfile.TemporaryDirectory()
+
+  def __del__(self):
+    if os.path.exists(self._temp_folder.name):
+      self._temp_folder.cleanup()
+
+  def populate(self):
+    """Creates the model asset bundle for hand landmarker task.
+
+    Returns:
+      Model asset bundle in bytes
+    """
+    landmark_models = {
+        _HAND_DETECTOR_TFLITE_NAME:
+            self._hand_detector_model_buffer,
+        _HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
+            self._hand_landmarks_detector_model_buffer
+    }
+    output_hand_landmarker_path = os.path.join(self._temp_folder.name,
+                                               _HAND_LANDMARKER_BUNDLE_NAME)
+    writer_utils.create_model_asset_bundle(landmark_models,
+                                           output_hand_landmarker_path)
+    hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
+    return hand_landmarker_model_buffer
+
+
 class MetadataWriter:
   """MetadataWriter to write the metadata and the model asset bundle."""
@@ -86,8 +130,8 @@ class MetadataWriter:
       custom_gesture_classifier_metadata_writer: Metadata writer to write custom
         gesture classifier metadata into the TFLite file.
     """
-    self._hand_detector_model_buffer = hand_detector_model_buffer
-    self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
+    self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
+        hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
     self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
     self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
     self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
@@ -147,16 +191,8 @@ class MetadataWriter:
       A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
     """
     # Creates the model asset bundle for hand landmarker task.
-    landmark_models = {
-        _HAND_DETECTOR_TFLITE_NAME:
-            self._hand_detector_model_buffer,
-        _HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
-            self._hand_landmarks_detector_model_buffer
-    }
-    output_hand_landmarker_path = os.path.join(self._temp_folder.name,
-                                               _HAND_LANDMARKER_BUNDLE_NAME)
-    writer_utils.create_model_asset_bundle(landmark_models,
-                                           output_hand_landmarker_path)
+    hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
+    )
 
     # Write metadata into custom gesture classifier model.
     self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
@@ -179,7 +215,7 @@ class MetadataWriter:
     # graph.
     gesture_recognizer_models = {
         _HAND_LANDMARKER_BUNDLE_NAME:
-            read_file(output_hand_landmarker_path),
+            hand_landmarker_model_buffer,
         _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
             read_file(output_hand_gesture_recognizer_path),
     }
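Note: HandLandmarkerMetadataWriter.populate() returns the hand_landmarker.task model asset bundle as bytes, and such a bundle is an ordinary zip archive of the packed TFLite models; the new unit test in the next file checks exactly this. A quick in-memory inspection sketch, assuming a writer constructed from two model buffers as in that test:

    import io
    import zipfile

    # writer: a metadata_writer.HandLandmarkerMetadataWriter instance.
    bundle = writer.populate()
    with zipfile.ZipFile(io.BytesIO(bundle)) as zf:
      # Expect hand_detector.tflite and hand_landmarks_detector.tflite.
      print(sorted(zf.namelist()))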

mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py

@@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
 class MetadataWriterTest(tf.test.TestCase):
 
+  def test_hand_landmarker_metadata_writer(self):
+    # Use dummy model buffer for unit test only.
+    hand_detector_model_buffer = b"\x11\x12"
+    hand_landmarks_detector_model_buffer = b"\x22"
+    writer = metadata_writer.HandLandmarkerMetadataWriter(
+        hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
+    model_bundle_content = writer.populate()
+    model_bundle_filepath = os.path.join(self.get_temp_dir(),
+                                         "hand_landmarker.task")
+    with open(model_bundle_filepath, "wb") as f:
+      f.write(model_bundle_content)
+
+    with zipfile.ZipFile(model_bundle_filepath) as zf:
+      self.assertEqual(
+          set(zf.namelist()),
+          set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
+
   def test_write_metadata_and_create_model_asset_bundle_successful(self):
     # Use dummy model buffer for unit test only.
     hand_detector_model_buffer = b"\x11\x12"

mediapipe/tasks/python/core/BUILD

@@ -23,15 +23,15 @@ py_library(
     srcs = [
         "optional_dependencies.py",
     ],
-    deps = [
-        "@org_tensorflow//tensorflow/tools/docs:doc_controls",
-    ],
 )
 
 py_library(
     name = "base_options",
     srcs = ["base_options.py"],
-    visibility = ["//mediapipe/tasks:users"],
+    visibility = [
+        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/tasks:users",
+    ],
     deps = [
         ":optional_dependencies",
         "//mediapipe/tasks/cc/core/proto:base_options_py_pb2",

mediapipe/tasks/python/vision/BUILD

@@ -131,6 +131,10 @@ py_library(
     srcs = [
         "hand_landmarker.py",
     ],
+    visibility = [
+        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/tasks:internal",
+    ],
     deps = [
         "//mediapipe/framework/formats:classification_py_pb2",
         "//mediapipe/framework/formats:landmark_py_pb2",