Update gesture recognizer to new mediapipe tasks pipeline
PiperOrigin-RevId: 493950564
This commit is contained in:
parent
13f8fa5139
commit
a641ea12e1
|
@ -35,20 +35,21 @@ py_library(
|
|||
srcs = ["constants.py"],
|
||||
)
|
||||
|
||||
# TODO: Change to py_library after migrating the MediaPipe hand solution
|
||||
# library to MediaPipe hand task library.
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
":constants",
|
||||
":metadata_writer",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/data:data_util",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Remove the notsan tag once tasks no longer has a race condition issue.
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
|
@ -56,10 +57,11 @@ py_test(
|
|||
":testdata",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -131,6 +133,7 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: Remove the notsan tag once tasks no longer has a race condition issue.
|
||||
py_test(
|
||||
name = "gesture_recognizer_test",
|
||||
size = "large",
|
||||
|
@ -140,6 +143,7 @@ py_test(
|
|||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
shard_count = 2,
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":gesture_recognizer_import",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
|
|
|
@ -16,16 +16,22 @@
|
|||
import dataclasses
|
||||
import os
|
||||
import random
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.core.data import data_util
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
|
||||
|
||||
_Image = image_module.Image
|
||||
_HandLandmarker = hand_landmarker_module.HandLandmarker
|
||||
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
|
||||
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -59,7 +65,7 @@ class HandData:
|
|||
handedness: List[float]
|
||||
|
||||
|
||||
def _validate_data_sample(data: NamedTuple) -> bool:
|
||||
def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
|
||||
"""Validates the input hand data sample.
|
||||
|
||||
Args:
|
||||
|
@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
|
|||
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
||||
or any of these attributes' values are none. Otherwise, True.
|
||||
"""
|
||||
if (not hasattr(data, 'multi_hand_landmarks') or
|
||||
data.multi_hand_landmarks is None):
|
||||
if data.hand_landmarks is None or not data.hand_landmarks:
|
||||
return False
|
||||
if (not hasattr(data, 'multi_hand_world_landmarks') or
|
||||
data.multi_hand_world_landmarks is None):
|
||||
if data.hand_world_landmarks is None or not data.hand_world_landmarks:
|
||||
return False
|
||||
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
|
||||
if data.handedness is None or not data.handedness:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_hand_data(all_image_paths: List[str],
|
||||
min_detection_confidence: float) -> Optional[HandData]:
|
||||
min_detection_confidence: float) -> List[Optional[HandData]]:
|
||||
"""Computes hand data (landmarks and handedness) in the input image.
|
||||
|
||||
Args:
|
||||
|
@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
|
|||
A HandData object. Returns None if no hand is detected.
|
||||
"""
|
||||
hand_data_result = []
|
||||
with mp_hands.Hands(
|
||||
static_image_mode=True,
|
||||
max_num_hands=1,
|
||||
min_detection_confidence=min_detection_confidence) as hands:
|
||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
hand_landmarker_options = _HandLandmarkerOptions(
|
||||
base_options=base_options_module.BaseOptions(
|
||||
model_asset_buffer=hand_landmarker_writer.populate()),
|
||||
num_hands=1,
|
||||
min_hand_detection_confidence=min_detection_confidence,
|
||||
min_hand_presence_confidence=0.5,
|
||||
min_tracking_confidence=1,
|
||||
)
|
||||
with _HandLandmarker.create_from_options(
|
||||
hand_landmarker_options) as hand_landmarker:
|
||||
for path in all_image_paths:
|
||||
tf.compat.v1.logging.info('Loading image %s', path)
|
||||
image = data_util.load_image(path)
|
||||
# Flip image around y-axis for correct handedness output
|
||||
image = cv2.flip(image, 1)
|
||||
data = hands.process(image)
|
||||
image = _Image.create_from_file(path)
|
||||
data = hand_landmarker.detect(image)
|
||||
if not _validate_data_sample(data):
|
||||
hand_data_result.append(None)
|
||||
continue
|
||||
hand_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
|
||||
hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
|
||||
for hand_landmark in data.hand_landmarks[0]]
|
||||
hand_world_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
|
||||
] for hand_landmark in data.hand_world_landmarks[0]]
|
||||
handedness_scores = [
|
||||
handedness.score
|
||||
for handedness in data.multi_handedness[0].classification
|
||||
handedness.score for handedness in data.handedness[0]
|
||||
]
|
||||
hand_data_result.append(
|
||||
HandData(
|
||||
|
|
|
@ -12,21 +12,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
from typing import NamedTuple
|
||||
import unittest
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
from mediapipe.tasks.python.vision import hand_landmarker
|
||||
|
||||
_TEST_DATA_DIRNAME = 'raw_data'
|
||||
|
||||
|
@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
train_data, test_data = data.split(0.5)
|
||||
|
||||
self.assertLen(train_data, 17)
|
||||
self.assertLen(train_data, 16)
|
||||
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertEqual(train_data.num_classes, 4)
|
||||
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
self.assertLen(test_data, 18)
|
||||
self.assertLen(test_data, 16)
|
||||
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
|
@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertLen(data, 32)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
|
@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertLen(data, 32)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_landmark',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmark', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_handedness',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=None, hand_landmarks=[[2]],
|
||||
hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_world_landmarks',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmark',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_hand_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=None,
|
||||
hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_handed',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handed'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_hand_world_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[[2]],
|
||||
hand_world_landmarks=None)),
|
||||
dict(
|
||||
testcase_name='multi_hand_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(None, 2, 3)),
|
||||
testcase_name='empty_handedness',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='multi_hand_world_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, None, 3)),
|
||||
testcase_name='empty_hand_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='multi_handedness_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, None)),
|
||||
testcase_name='empty_hand_world_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
||||
)
|
||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||
with unittest.mock.patch.object(
|
||||
mp_hands.Hands, 'process', return_value=hand):
|
||||
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||
dataset.Dataset.from_folder(
|
||||
|
|
|
@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
|||
return f.read()
|
||||
|
||||
|
||||
class HandLandmarkerMetadataWriter:
  """MetadataWriter to write the model asset bundle for HandLandmarker."""

  def __init__(
      self,
      hand_detector_model_buffer: bytearray,
      hand_landmarks_detector_model_buffer: bytearray,
  ) -> None:
    """Initializes HandLandmarkerMetadataWriter to write model asset bundle.

    Args:
      hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
        the TFLite hand detector model file.
      hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
        loaded from the TFLite hand landmarks detector model file.
    """
    self._hand_detector_model_buffer = hand_detector_model_buffer
    self._hand_landmarks_detector_model_buffer = (
        hand_landmarks_detector_model_buffer)
    # Scratch directory that populate() writes the bundle file into; it is
    # removed again in __del__.
    self._temp_folder = tempfile.TemporaryDirectory()

  def __del__(self):
    # __del__ also runs on instances whose __init__ did not finish (or that
    # were created without calling __init__), in which case `_temp_folder` was
    # never assigned. Look it up defensively so finalization never raises.
    temp_folder = getattr(self, "_temp_folder", None)
    if temp_folder is not None and os.path.exists(temp_folder.name):
      temp_folder.cleanup()

  def populate(self):
    """Creates the model asset bundle for hand landmarker task.

    The two model buffers are bundled into a file inside this writer's
    temporary folder, and that file is then read back.

    Returns:
      Model asset bundle in bytes
    """
    landmark_models = {
        _HAND_DETECTOR_TFLITE_NAME:
            self._hand_detector_model_buffer,
        _HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
            self._hand_landmarks_detector_model_buffer
    }
    output_hand_landmarker_path = os.path.join(self._temp_folder.name,
                                               _HAND_LANDMARKER_BUNDLE_NAME)
    writer_utils.create_model_asset_bundle(landmark_models,
                                           output_hand_landmarker_path)
    # read_file defaults to binary mode, so the bundle comes back as bytes.
    return read_file(output_hand_landmarker_path)
|
||||
|
||||
|
||||
class MetadataWriter:
|
||||
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||
|
||||
|
@ -86,8 +130,8 @@ class MetadataWriter:
|
|||
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||
gesture classifier metadata into the TFLite file.
|
||||
"""
|
||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||
self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||
|
@ -147,16 +191,8 @@ class MetadataWriter:
|
|||
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||
"""
|
||||
# Creates the model asset bundle for hand landmarker task.
|
||||
landmark_models = {
|
||||
_HAND_DETECTOR_TFLITE_NAME:
|
||||
self._hand_detector_model_buffer,
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||
self._hand_landmarks_detector_model_buffer
|
||||
}
|
||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(landmark_models,
|
||||
output_hand_landmarker_path)
|
||||
hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
|
||||
)
|
||||
|
||||
# Write metadata into custom gesture classifier model.
|
||||
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
||||
|
@ -179,7 +215,7 @@ class MetadataWriter:
|
|||
# graph.
|
||||
gesture_recognizer_models = {
|
||||
_HAND_LANDMARKER_BUNDLE_NAME:
|
||||
read_file(output_hand_landmarker_path),
|
||||
hand_landmarker_model_buffer,
|
||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
||||
read_file(output_hand_gesture_recognizer_path),
|
||||
}
|
||||
|
|
|
@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
|
|||
|
||||
class MetadataWriterTest(tf.test.TestCase):
|
||||
|
||||
def test_hand_landmarker_metadata_writer(self):
  """Checks that the populated bundle is a zip holding both model files."""
  # Use dummy model buffer for unit test only.
  detector_bytes = b"\x11\x12"
  landmarks_bytes = b"\x22"
  bundle_writer = metadata_writer.HandLandmarkerMetadataWriter(
      detector_bytes, landmarks_bytes)
  bundle_bytes = bundle_writer.populate()

  # Persist the bundle so it can be opened as a zip archive.
  bundle_path = os.path.join(self.get_temp_dir(), "hand_landmarker.task")
  with open(bundle_path, "wb") as bundle_file:
    bundle_file.write(bundle_bytes)

  expected_entries = {"hand_landmarks_detector.tflite", "hand_detector.tflite"}
  with zipfile.ZipFile(bundle_path) as archive:
    actual_entries = set(archive.namelist())
    self.assertEqual(actual_entries, expected_entries)
|
||||
|
||||
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||
# Use dummy model buffer for unit test only.
|
||||
hand_detector_model_buffer = b"\x11\x12"
|
||||
|
|
|
@ -23,15 +23,15 @@ py_library(
|
|||
srcs = [
|
||||
"optional_dependencies.py",
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/tools/docs:doc_controls",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_options",
|
||||
srcs = ["base_options.py"],
|
||||
visibility = ["//mediapipe/tasks:users"],
|
||||
visibility = [
|
||||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
|
||||
"//mediapipe/tasks:users",
|
||||
],
|
||||
deps = [
|
||||
":optional_dependencies",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
|
||||
|
|
|
@ -131,6 +131,10 @@ py_library(
|
|||
srcs = [
|
||||
"hand_landmarker.py",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_py_pb2",
|
||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||
|
|
Loading…
Reference in New Issue
Block a user