Update gesture recognizer to new mediapipe tasks pipeline

PiperOrigin-RevId: 493950564
MediaPipe Team 2022-12-08 11:30:39 -08:00 committed by Copybara-Service
parent 13f8fa5139
commit a641ea12e1
7 changed files with 147 additions and 87 deletions

View File

@@ -35,20 +35,21 @@ py_library(
srcs = ["constants.py"],
)
# TODO: Change to py_library after migrating the MediaPipe hand solution
# library to MediaPipe hand task library.
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
":metadata_writer",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/data:data_util",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/python/solutions:hands",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/vision:hand_landmarker",
],
)
# TODO: Remove notsan tag once tasks no longer has race condition issue
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
@@ -56,10 +57,11 @@ py_test(
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
tags = ["notsan"],
deps = [
":dataset",
"//mediapipe/python/solutions:hands",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:hand_landmarker",
],
)
@@ -131,6 +133,7 @@ py_library(
],
)
# TODO: Remove notsan tag once tasks no longer has race condition issue
py_test(
name = "gesture_recognizer_test",
size = "large",
@@ -140,6 +143,7 @@ py_test(
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
shard_count = 2,
tags = ["notsan"],
deps = [
":gesture_recognizer_import",
"//mediapipe/model_maker/python/core/utils:test_util",

View File

@@ -16,16 +16,22 @@
import dataclasses
import os
import random
from typing import List, NamedTuple, Optional
from typing import List, Optional
import cv2
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.core.data import data_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
_Image = image_module.Image
_HandLandmarker = hand_landmarker_module.HandLandmarker
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
@dataclasses.dataclass
@@ -59,7 +65,7 @@ class HandData:
handedness: List[float]
def _validate_data_sample(data: NamedTuple) -> bool:
def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
"""Validates the input hand data sample.
Args:
@@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
'hand_landmarks', 'hand_world_landmarks', or 'handedness', or if any of these
attributes' values are None or empty. Otherwise, True.
"""
if (not hasattr(data, 'multi_hand_landmarks') or
data.multi_hand_landmarks is None):
if data.hand_landmarks is None or not data.hand_landmarks:
return False
if (not hasattr(data, 'multi_hand_world_landmarks') or
data.multi_hand_world_landmarks is None):
if data.hand_world_landmarks is None or not data.hand_world_landmarks:
return False
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
if data.handedness is None or not data.handedness:
return False
return True
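The validator now takes the Tasks result type directly, so a malformed detection can be reproduced in isolation. A minimal sketch, assuming it runs inside this module (the constructor arguments mirror the updated dataset test below):

from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module

# An all-empty result must fail validation; every field needs a non-empty list.
empty_result = hand_landmarker_module.HandLandmarkerResult(
    handedness=[], hand_landmarks=[], hand_world_landmarks=[])
assert not _validate_data_sample(empty_result)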
def _get_hand_data(all_image_paths: List[str],
min_detection_confidence: float) -> Optional[HandData]:
min_detection_confidence: float) -> List[Optional[HandData]]:
"""Computes hand data (landmarks and handedness) in the input image.
Args:
@@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
A list of HandData objects; an entry is None for images where no hand is detected.
"""
hand_data_result = []
with mp_hands.Hands(
static_image_mode=True,
max_num_hands=1,
min_detection_confidence=min_detection_confidence) as hands:
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
hand_landmarker_options = _HandLandmarkerOptions(
base_options=base_options_module.BaseOptions(
model_asset_buffer=hand_landmarker_writer.populate()),
num_hands=1,
min_hand_detection_confidence=min_detection_confidence,
min_hand_presence_confidence=0.5,
min_tracking_confidence=1,
)
with _HandLandmarker.create_from_options(
hand_landmarker_options) as hand_landmarker:
for path in all_image_paths:
tf.compat.v1.logging.info('Loading image %s', path)
image = data_util.load_image(path)
# Flip image around y-axis for correct handedness output
image = cv2.flip(image, 1)
data = hands.process(image)
image = _Image.create_from_file(path)
data = hand_landmarker.detect(image)
if not _validate_data_sample(data):
hand_data_result.append(None)
continue
hand_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
for hand_landmark in data.hand_landmarks[0]]
hand_world_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
] for hand_landmark in data.hand_world_landmarks[0]]
handedness_scores = [
handedness.score
for handedness in data.multi_handedness[0].classification
handedness.score for handedness in data.handedness[0]
]
hand_data_result.append(
HandData(

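Taken together, the dataset change swaps the solutions-API Hands context manager for the Tasks HandLandmarker. A minimal end-to-end sketch of the new flow, assuming a prebuilt 'hand_landmarker.task' bundle and a local test image rather than the buffers assembled above:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module

options = hand_landmarker_module.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='hand_landmarker.task'),  # assumed local bundle
    num_hands=1,
    min_hand_detection_confidence=0.7)
with hand_landmarker_module.HandLandmarker.create_from_options(
    options) as landmarker:
  # Note: the cv2.flip handedness workaround from the old solutions path
  # is dropped in this commit.
  image = image_module.Image.create_from_file('hand.jpg')  # assumed image
  result = landmarker.detect(image)
  if result.hand_landmarks:
    # Each detected hand is a plain list of landmarks with x/y/z attributes.
    xyz = [(lm.x, lm.y, lm.z) for lm in result.hand_landmarks[0]]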
View File

@@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import shutil
from typing import NamedTuple
import unittest
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.tasks.python.test import test_utils
FLAGS = flags.FLAGS
from mediapipe.tasks.python.vision import hand_landmarker
_TEST_DATA_DIRNAME = 'raw_data'
@@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 17)
self.assertLen(train_data, 16)
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(train_data.num_classes, 4)
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
self.assertLen(test_data, 18)
self.assertLen(test_data, 16)
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
@@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertLen(data, 32)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
@@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertLen(data, 32)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
@parameterized.named_parameters(
dict(
testcase_name='invalid_field_name_multi_hand_landmark',
hand=collections.namedtuple('Hand', [
'multi_hand_landmark', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, 3)),
testcase_name='none_handedness',
hand=hand_landmarker.HandLandmarkerResult(
handedness=None, hand_landmarks=[[2]],
hand_world_landmarks=[[3]])),
dict(
testcase_name='invalid_field_name_multi_hand_world_landmarks',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmark',
'multi_handedness'
])(1, 2, 3)),
testcase_name='none_hand_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=None,
hand_world_landmarks=[[3]])),
dict(
testcase_name='invalid_field_name_multi_handed',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handed'
])(1, 2, 3)),
testcase_name='none_hand_world_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[[2]],
hand_world_landmarks=None)),
dict(
testcase_name='multi_hand_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(None, 2, 3)),
testcase_name='empty_handedness',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
dict(
testcase_name='multi_hand_world_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, None, 3)),
testcase_name='empty_hand_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
dict(
testcase_name='multi_handedness_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, None)),
testcase_name='empty_hand_world_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
)
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object(
mp_hands.Hands, 'process', return_value=hand):
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder(

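The invalid-data tests above no longer fabricate namedtuples; they patch the Tasks API itself. The same mocking pattern in isolation, as a sketch for reuse in other tests:

import unittest.mock

from mediapipe.tasks.python.vision import hand_landmarker

# Synthetic "no hands found" result, built the way the parameterized cases do.
empty_result = hand_landmarker.HandLandmarkerResult(
    handedness=[], hand_landmarks=[], hand_world_landmarks=[])
with unittest.mock.patch.object(
    hand_landmarker.HandLandmarker, 'detect', return_value=empty_result):
  pass  # code under test now sees every detect() call return no hands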
View File

@@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
return f.read()
class HandLandmarkerMetadataWriter:
"""MetadataWriter to write the model asset bundle for HandLandmarker."""
def __init__(
self,
hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
) -> None:
"""Initializes HandLandmarkerMetadataWriter to write model asset bundle.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
def populate(self):
"""Creates the model asset bundle for hand landmarker task.
Returns:
Model asset bundle in bytes
"""
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
return hand_landmarker_model_buffer
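Elsewhere in this commit, dataset.py drives the new writer with buffers loaded through model_util; a condensed sketch of that call path (all names come from this commit):

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer

detector_buffer = model_util.load_tflite_model_buffer(
    constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
landmarks_buffer = model_util.load_tflite_model_buffer(
    constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
writer = metadata_writer.HandLandmarkerMetadataWriter(detector_buffer,
                                                      landmarks_buffer)
bundle_bytes = writer.populate()  # zip-archive bytes for the .task bundle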
class MetadataWriter:
"""MetadataWriter to write the metadata and the model asset bundle."""
@@ -86,8 +130,8 @@ class MetadataWriter:
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
gesture classifier metadata into the TFLite file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
@@ -147,16 +191,8 @@ class MetadataWriter:
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
"""
# Creates the model asset bundle for hand landmarker task.
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
)
# Write metadata into custom gesture classifier model.
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
@@ -179,7 +215,7 @@ class MetadataWriter:
# graph.
gesture_recognizer_models = {
_HAND_LANDMARKER_BUNDLE_NAME:
read_file(output_hand_landmarker_path),
hand_landmarker_model_buffer,
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
read_file(output_hand_gesture_recognizer_path),
}

View File

@@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
class MetadataWriterTest(tf.test.TestCase):
def test_hand_landmarker_metadata_writer(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"
hand_landmarks_detector_model_buffer = b"\x22"
writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
model_bundle_content = writer.populate()
model_bundle_filepath = os.path.join(self.get_temp_dir(),
"hand_landmarker.task")
with open(model_bundle_filepath, "wb") as f:
f.write(model_bundle_content)
with zipfile.ZipFile(model_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
def test_write_metadata_and_create_model_asset_bundle_successful(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"

View File

@@ -23,15 +23,15 @@ py_library(
srcs = [
"optional_dependencies.py",
],
deps = [
"@org_tensorflow//tensorflow/tools/docs:doc_controls",
],
)
py_library(
name = "base_options",
srcs = ["base_options.py"],
visibility = ["//mediapipe/tasks:users"],
visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
"//mediapipe/tasks:users",
],
deps = [
":optional_dependencies",
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2",

View File

@@ -131,6 +131,10 @@ py_library(
srcs = [
"hand_landmarker.py",
],
visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
"//mediapipe/tasks:internal",
],
deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",