Update gesture recognizer to new mediapipe tasks pipeline
PiperOrigin-RevId: 493950564
This commit is contained in:
parent
13f8fa5139
commit
a641ea12e1
|
@ -35,20 +35,21 @@ py_library(
|
||||||
srcs = ["constants.py"],
|
srcs = ["constants.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Change to py_library after migrating the MediaPipe hand solution
|
|
||||||
# library to MediaPipe hand task library.
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
|
":metadata_writer",
|
||||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
"//mediapipe/model_maker/python/core/data:data_util",
|
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/python/solutions:hands",
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Remove notsan tag once tasks no longer has race condition issue
|
||||||
py_test(
|
py_test(
|
||||||
name = "dataset_test",
|
name = "dataset_test",
|
||||||
srcs = ["dataset_test.py"],
|
srcs = ["dataset_test.py"],
|
||||||
|
@ -56,10 +57,11 @@ py_test(
|
||||||
":testdata",
|
":testdata",
|
||||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||||
],
|
],
|
||||||
|
tags = ["notsan"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
"//mediapipe/python/solutions:hands",
|
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -131,6 +133,7 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Remove notsan tag once tasks no longer has race condition issue
|
||||||
py_test(
|
py_test(
|
||||||
name = "gesture_recognizer_test",
|
name = "gesture_recognizer_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
|
@ -140,6 +143,7 @@ py_test(
|
||||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||||
],
|
],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
|
tags = ["notsan"],
|
||||||
deps = [
|
deps = [
|
||||||
":gesture_recognizer_import",
|
":gesture_recognizer_import",
|
||||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||||
|
|
|
@ -16,16 +16,22 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import List, NamedTuple, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import cv2
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
from mediapipe.model_maker.python.core.data import data_util
|
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||||
from mediapipe.python.solutions import hands as mp_hands
|
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||||
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
|
||||||
|
|
||||||
|
_Image = image_module.Image
|
||||||
|
_HandLandmarker = hand_landmarker_module.HandLandmarker
|
||||||
|
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
|
||||||
|
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -59,7 +65,7 @@ class HandData:
|
||||||
handedness: List[float]
|
handedness: List[float]
|
||||||
|
|
||||||
|
|
||||||
def _validate_data_sample(data: NamedTuple) -> bool:
|
def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
|
||||||
"""Validates the input hand data sample.
|
"""Validates the input hand data sample.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
|
||||||
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
||||||
or any of these attributes' values are none. Otherwise, True.
|
or any of these attributes' values are none. Otherwise, True.
|
||||||
"""
|
"""
|
||||||
if (not hasattr(data, 'multi_hand_landmarks') or
|
if data.hand_landmarks is None or not data.hand_landmarks:
|
||||||
data.multi_hand_landmarks is None):
|
|
||||||
return False
|
return False
|
||||||
if (not hasattr(data, 'multi_hand_world_landmarks') or
|
if data.hand_world_landmarks is None or not data.hand_world_landmarks:
|
||||||
data.multi_hand_world_landmarks is None):
|
|
||||||
return False
|
return False
|
||||||
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
|
if data.handedness is None or not data.handedness:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _get_hand_data(all_image_paths: List[str],
|
def _get_hand_data(all_image_paths: List[str],
|
||||||
min_detection_confidence: float) -> Optional[HandData]:
|
min_detection_confidence: float) -> List[Optional[HandData]]:
|
||||||
"""Computes hand data (landmarks and handedness) in the input image.
|
"""Computes hand data (landmarks and handedness) in the input image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
|
||||||
A HandData object. Returns None if no hand is detected.
|
A HandData object. Returns None if no hand is detected.
|
||||||
"""
|
"""
|
||||||
hand_data_result = []
|
hand_data_result = []
|
||||||
with mp_hands.Hands(
|
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
static_image_mode=True,
|
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||||
max_num_hands=1,
|
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
min_detection_confidence=min_detection_confidence) as hands:
|
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||||
|
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||||
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||||
|
hand_landmarker_options = _HandLandmarkerOptions(
|
||||||
|
base_options=base_options_module.BaseOptions(
|
||||||
|
model_asset_buffer=hand_landmarker_writer.populate()),
|
||||||
|
num_hands=1,
|
||||||
|
min_hand_detection_confidence=min_detection_confidence,
|
||||||
|
min_hand_presence_confidence=0.5,
|
||||||
|
min_tracking_confidence=1,
|
||||||
|
)
|
||||||
|
with _HandLandmarker.create_from_options(
|
||||||
|
hand_landmarker_options) as hand_landmarker:
|
||||||
for path in all_image_paths:
|
for path in all_image_paths:
|
||||||
tf.compat.v1.logging.info('Loading image %s', path)
|
tf.compat.v1.logging.info('Loading image %s', path)
|
||||||
image = data_util.load_image(path)
|
image = _Image.create_from_file(path)
|
||||||
# Flip image around y-axis for correct handedness output
|
data = hand_landmarker.detect(image)
|
||||||
image = cv2.flip(image, 1)
|
|
||||||
data = hands.process(image)
|
|
||||||
if not _validate_data_sample(data):
|
if not _validate_data_sample(data):
|
||||||
hand_data_result.append(None)
|
hand_data_result.append(None)
|
||||||
continue
|
continue
|
||||||
hand_landmarks = [[
|
hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
|
||||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
for hand_landmark in data.hand_landmarks[0]]
|
||||||
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
|
|
||||||
hand_world_landmarks = [[
|
hand_world_landmarks = [[
|
||||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||||
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
|
] for hand_landmark in data.hand_world_landmarks[0]]
|
||||||
handedness_scores = [
|
handedness_scores = [
|
||||||
handedness.score
|
handedness.score for handedness in data.handedness[0]
|
||||||
for handedness in data.multi_handedness[0].classification
|
|
||||||
]
|
]
|
||||||
hand_data_result.append(
|
hand_data_result.append(
|
||||||
HandData(
|
HandData(
|
||||||
|
|
|
@ -12,21 +12,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import collections
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from absl import flags
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||||
from mediapipe.python.solutions import hands as mp_hands
|
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
from mediapipe.tasks.python.vision import hand_landmarker
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
_TEST_DATA_DIRNAME = 'raw_data'
|
_TEST_DATA_DIRNAME = 'raw_data'
|
||||||
|
|
||||||
|
@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||||
train_data, test_data = data.split(0.5)
|
train_data, test_data = data.split(0.5)
|
||||||
|
|
||||||
self.assertLen(train_data, 17)
|
self.assertLen(train_data, 16)
|
||||||
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
||||||
self.assertEqual(elem[0].shape, (1, 128))
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
self.assertEqual(train_data.num_classes, 4)
|
self.assertEqual(train_data.num_classes, 4)
|
||||||
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||||
|
|
||||||
self.assertLen(test_data, 18)
|
self.assertLen(test_data, 16)
|
||||||
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
||||||
self.assertEqual(elem[0].shape, (1, 128))
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
|
@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||||
self.assertEqual(elem[0].shape, (1, 128))
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
self.assertLen(data, 35)
|
self.assertLen(data, 32)
|
||||||
self.assertEqual(data.num_classes, 4)
|
self.assertEqual(data.num_classes, 4)
|
||||||
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
||||||
|
|
||||||
|
@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||||
self.assertEqual(elem[0].shape, (1, 128))
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
self.assertLen(data, 35)
|
self.assertLen(data, 32)
|
||||||
self.assertEqual(data.num_classes, 4)
|
self.assertEqual(data.num_classes, 4)
|
||||||
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(
|
dict(
|
||||||
testcase_name='invalid_field_name_multi_hand_landmark',
|
testcase_name='none_handedness',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmark', 'multi_hand_world_landmarks',
|
handedness=None, hand_landmarks=[[2]],
|
||||||
'multi_handedness'
|
hand_world_landmarks=[[3]])),
|
||||||
])(1, 2, 3)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='invalid_field_name_multi_hand_world_landmarks',
|
testcase_name='none_hand_landmarks',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmarks', 'multi_hand_world_landmark',
|
handedness=[[1]], hand_landmarks=None,
|
||||||
'multi_handedness'
|
hand_world_landmarks=[[3]])),
|
||||||
])(1, 2, 3)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='invalid_field_name_multi_handed',
|
testcase_name='none_hand_world_landmarks',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
handedness=[[1]], hand_landmarks=[[2]],
|
||||||
'multi_handed'
|
hand_world_landmarks=None)),
|
||||||
])(1, 2, 3)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='multi_hand_landmarks_is_none',
|
testcase_name='empty_handedness',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
|
||||||
'multi_handedness'
|
|
||||||
])(None, 2, 3)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='multi_hand_world_landmarks_is_none',
|
testcase_name='empty_hand_landmarks',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
|
||||||
'multi_handedness'
|
|
||||||
])(1, None, 3)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='multi_handedness_is_none',
|
testcase_name='empty_hand_world_landmarks',
|
||||||
hand=collections.namedtuple('Hand', [
|
hand=hand_landmarker.HandLandmarkerResult(
|
||||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
||||||
'multi_handedness'
|
|
||||||
])(1, 2, None)),
|
|
||||||
)
|
)
|
||||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||||
with unittest.mock.patch.object(
|
with unittest.mock.patch.object(
|
||||||
mp_hands.Hands, 'process', return_value=hand):
|
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
|
||||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||||
dataset.Dataset.from_folder(
|
dataset.Dataset.from_folder(
|
||||||
|
|
|
@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
class HandLandmarkerMetadataWriter:
|
||||||
|
"""MetadataWriter to write the model asset bundle for HandLandmarker."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hand_detector_model_buffer: bytearray,
|
||||||
|
hand_landmarks_detector_model_buffer: bytearray,
|
||||||
|
) -> None:
|
||||||
|
"""Initializes HandLandmarkerMetadataWriter to write model asset bundle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||||
|
the TFLite hand detector model file.
|
||||||
|
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||||
|
loaded from the TFLite hand landmarks detector model file.
|
||||||
|
"""
|
||||||
|
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||||
|
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||||
|
self._temp_folder = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if os.path.exists(self._temp_folder.name):
|
||||||
|
self._temp_folder.cleanup()
|
||||||
|
|
||||||
|
def populate(self):
|
||||||
|
"""Creates the model asset bundle for hand landmarker task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model asset bundle in bytes
|
||||||
|
"""
|
||||||
|
landmark_models = {
|
||||||
|
_HAND_DETECTOR_TFLITE_NAME:
|
||||||
|
self._hand_detector_model_buffer,
|
||||||
|
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||||
|
self._hand_landmarks_detector_model_buffer
|
||||||
|
}
|
||||||
|
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||||
|
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||||
|
writer_utils.create_model_asset_bundle(landmark_models,
|
||||||
|
output_hand_landmarker_path)
|
||||||
|
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
|
||||||
|
return hand_landmarker_model_buffer
|
||||||
|
|
||||||
|
|
||||||
class MetadataWriter:
|
class MetadataWriter:
|
||||||
"""MetadataWriter to write the metadata and the model asset bundle."""
|
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||||
|
|
||||||
|
@ -86,8 +130,8 @@ class MetadataWriter:
|
||||||
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||||
gesture classifier metadata into the TFLite file.
|
gesture classifier metadata into the TFLite file.
|
||||||
"""
|
"""
|
||||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
|
||||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||||
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||||
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||||
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||||
|
@ -147,16 +191,8 @@ class MetadataWriter:
|
||||||
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||||
"""
|
"""
|
||||||
# Creates the model asset bundle for hand landmarker task.
|
# Creates the model asset bundle for hand landmarker task.
|
||||||
landmark_models = {
|
hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
|
||||||
_HAND_DETECTOR_TFLITE_NAME:
|
)
|
||||||
self._hand_detector_model_buffer,
|
|
||||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
|
||||||
self._hand_landmarks_detector_model_buffer
|
|
||||||
}
|
|
||||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
|
||||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
|
||||||
writer_utils.create_model_asset_bundle(landmark_models,
|
|
||||||
output_hand_landmarker_path)
|
|
||||||
|
|
||||||
# Write metadata into custom gesture classifier model.
|
# Write metadata into custom gesture classifier model.
|
||||||
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
||||||
|
@ -179,7 +215,7 @@ class MetadataWriter:
|
||||||
# graph.
|
# graph.
|
||||||
gesture_recognizer_models = {
|
gesture_recognizer_models = {
|
||||||
_HAND_LANDMARKER_BUNDLE_NAME:
|
_HAND_LANDMARKER_BUNDLE_NAME:
|
||||||
read_file(output_hand_landmarker_path),
|
hand_landmarker_model_buffer,
|
||||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
||||||
read_file(output_hand_gesture_recognizer_path),
|
read_file(output_hand_gesture_recognizer_path),
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
|
||||||
|
|
||||||
class MetadataWriterTest(tf.test.TestCase):
|
class MetadataWriterTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_hand_landmarker_metadata_writer(self):
|
||||||
|
# Use dummy model buffer for unit test only.
|
||||||
|
hand_detector_model_buffer = b"\x11\x12"
|
||||||
|
hand_landmarks_detector_model_buffer = b"\x22"
|
||||||
|
writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||||
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||||
|
model_bundle_content = writer.populate()
|
||||||
|
model_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||||
|
"hand_landmarker.task")
|
||||||
|
with open(model_bundle_filepath, "wb") as f:
|
||||||
|
f.write(model_bundle_content)
|
||||||
|
|
||||||
|
with zipfile.ZipFile(model_bundle_filepath) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
|
||||||
|
|
||||||
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||||
# Use dummy model buffer for unit test only.
|
# Use dummy model buffer for unit test only.
|
||||||
hand_detector_model_buffer = b"\x11\x12"
|
hand_detector_model_buffer = b"\x11\x12"
|
||||||
|
|
|
@ -23,15 +23,15 @@ py_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"optional_dependencies.py",
|
"optional_dependencies.py",
|
||||||
],
|
],
|
||||||
deps = [
|
|
||||||
"@org_tensorflow//tensorflow/tools/docs:doc_controls",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "base_options",
|
name = "base_options",
|
||||||
srcs = ["base_options.py"],
|
srcs = ["base_options.py"],
|
||||||
visibility = ["//mediapipe/tasks:users"],
|
visibility = [
|
||||||
|
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
|
||||||
|
"//mediapipe/tasks:users",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":optional_dependencies",
|
":optional_dependencies",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
|
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
|
||||||
|
|
|
@ -131,6 +131,10 @@ py_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"hand_landmarker.py",
|
"hand_landmarker.py",
|
||||||
],
|
],
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
|
||||||
|
"//mediapipe/tasks:internal",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/formats:classification_py_pb2",
|
"//mediapipe/framework/formats:classification_py_pb2",
|
||||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user