Fixed some issues in the MatrixData container, revised the implementation and added more tests

This commit is contained in:
kinaryml 2023-03-15 10:41:36 -07:00
parent 06c37c6442
commit 4a6015e65c
4 changed files with 341 additions and 27 deletions

View File

@ -24,6 +24,11 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_MatrixDataProto = matrix_data_pb2.MatrixData _MatrixDataProto = matrix_data_pb2.MatrixData
class Layout(enum.Enum):
COLUMN_MAJOR = 0
ROW_MAJOR = 1
@dataclasses.dataclass @dataclasses.dataclass
class MatrixData: class MatrixData:
"""This stores the Matrix data. """This stores the Matrix data.
@ -37,10 +42,6 @@ class MatrixData:
layout: The order in which the data are stored. Defaults to COLUMN_MAJOR. layout: The order in which the data are stored. Defaults to COLUMN_MAJOR.
""" """
class Layout(enum.Enum):
COLUMN_MAJOR = 0
ROW_MAJOR = 1
rows: int = None rows: int = None
cols: int = None cols: int = None
data: np.ndarray = None data: np.ndarray = None
@ -52,8 +53,8 @@ class MatrixData:
return _MatrixDataProto( return _MatrixDataProto(
rows=self.rows, rows=self.rows,
cols=self.cols, cols=self.cols,
data=self.data.tolist(), packed_data=self.data,
layout=self.layout) layout=self.layout.value)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -62,8 +63,8 @@ class MatrixData:
return MatrixData( return MatrixData(
rows=pb2_obj.rows, rows=pb2_obj.rows,
cols=pb2_obj.cols, cols=pb2_obj.cols,
data=np.array(pb2_obj.data), data=np.array(pb2_obj.packed_data),
layout=pb2_obj.layout) layout=Layout(pb2_obj.layout))
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.

View File

@ -50,12 +50,13 @@ _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_FACE_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker.task' _FACE_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker.task'
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE = 'face_landmarker_with_blendshapes.task' _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE = 'face_landmarker_with_blendshapes.task'
_PORTRAIT_IMAGE = 'portrait.jpg' _PORTRAIT_IMAGE = 'portrait.jpg'
_CAT_IMAGE = 'cat.jpg'
_PORTRAIT_EXPECTED_FACE_LANDMARKS = 'portrait_expected_face_landmarks.pbtxt' _PORTRAIT_EXPECTED_FACE_LANDMARKS = 'portrait_expected_face_landmarks.pbtxt'
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION = 'portrait_expected_face_landmarks_with_attention.pbtxt' _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION = 'portrait_expected_face_landmarks_with_attention.pbtxt'
_PORTRAIT_EXPECTED_BLENDSHAPES = 'portrait_expected_blendshapes_with_attention.pbtxt' _PORTRAIT_EXPECTED_BLENDSHAPES = 'portrait_expected_blendshapes_with_attention.pbtxt'
_PORTRAIT_EXPECTED_FACE_GEOMETRY = 'portrait_expected_face_geometry_with_attention.pbtxt' _PORTRAIT_EXPECTED_FACE_GEOMETRY = 'portrait_expected_face_geometry_with_attention.pbtxt'
_LANDMARKS_DIFF_MARGIN = 0.03 _LANDMARKS_DIFF_MARGIN = 0.03
_BLENDSHAPES_DIFF_MARGIN = 0.1 _BLENDSHAPES_DIFF_MARGIN = 0.12
_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02 _FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02
@ -90,12 +91,12 @@ def _get_expected_face_blendshapes(file_path: str):
def _make_expected_facial_transformation_matrixes(): def _make_expected_facial_transformation_matrixes():
data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546], data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
[0.0072318087, 0.99744856, -0.07102106, 22.212194], [0.0072318087, 0.99744856, -0.07102106, 22.212194],
[-0.029815676, 0.07120642, 0.9970159, -64.76358], [-0.029815676, 0.07120642, 0.9970159, -64.76358],
[0, 0, 0, 1]]) [0, 0, 0, 1]])
rows, cols = len(data), len(data[0]) rows, cols = len(data), len(data[0])
facial_transformation_matrixes_results = [] facial_transformation_matrixes_results = []
facial_transformation_matrix = _MatrixData(rows, cols, data) facial_transformation_matrix = _MatrixData(rows, cols, data.flatten())
facial_transformation_matrixes_results.append(facial_transformation_matrix) facial_transformation_matrixes_results.append(facial_transformation_matrix)
return facial_transformation_matrixes_results return facial_transformation_matrixes_results
@ -147,8 +148,8 @@ class FaceLandmarkerTest(parameterized.TestCase):
self.assertEqual(rename_me.rows, expected_matrix_list[i].rows) self.assertEqual(rename_me.rows, expected_matrix_list[i].rows)
self.assertEqual(rename_me.cols, expected_matrix_list[i].cols) self.assertEqual(rename_me.cols, expected_matrix_list[i].cols)
self.assertAlmostEqual( self.assertAlmostEqual(
rename_me.data, rename_me.data.all(),
expected_matrix_list[i].data, expected_matrix_list[i].data.all(),
delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN) delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
@ -220,10 +221,10 @@ class FaceLandmarkerTest(parameterized.TestCase):
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes( _get_expected_face_blendshapes(
_PORTRAIT_EXPECTED_BLENDSHAPES), _PORTRAIT_EXPECTED_BLENDSHAPES),
_make_expected_facial_transformation_matrixes()) _make_expected_facial_transformation_matrixes()))
) def test_detect(
def test_detect(self, model_file_type, model_name, expected_face_landmarks, self, model_file_type, model_name, expected_face_landmarks,
expected_face_blendshapes, expected_facial_transformation_matrix): expected_face_blendshapes, expected_facial_transformation_matrixes):
# Creates face landmarker. # Creates face landmarker.
model_path = test_utils.get_test_data_path(model_name) model_path = test_utils.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
@ -240,7 +241,7 @@ class FaceLandmarkerTest(parameterized.TestCase):
base_options=base_options, base_options=base_options,
output_face_blendshapes=True if expected_face_blendshapes else False, output_face_blendshapes=True if expected_face_blendshapes else False,
output_facial_transformation_matrixes=True output_facial_transformation_matrixes=True
if expected_facial_transformation_matrix else False) if expected_facial_transformation_matrixes else False)
landmarker = _FaceLandmarker.create_from_options(options) landmarker = _FaceLandmarker.create_from_options(options)
# Performs face landmarks detection on the input. # Performs face landmarks detection on the input.
@ -252,15 +253,317 @@ class FaceLandmarkerTest(parameterized.TestCase):
if expected_face_blendshapes is not None: if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(detection_result.face_blendshapes[0], self._expect_blendshapes_correct(detection_result.face_blendshapes[0],
expected_face_blendshapes) expected_face_blendshapes)
if expected_facial_transformation_matrix is not None: if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrix_correct( self._expect_facial_transformation_matrix_correct(
detection_result.facial_transformation_matrixes[0], detection_result.facial_transformation_matrixes,
expected_facial_transformation_matrix) expected_facial_transformation_matrixes)
# Closes the face landmarker explicitly when the face landmarker is not used # Closes the face landmarker explicitly when the face landmarker is not used
# in a context. # in a context.
landmarker.close() landmarker.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _FACE_LANDMARKER_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
(ModelFileType.FILE_CONTENT, _FACE_LANDMARKER_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
(ModelFileType.FILE_NAME,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
(ModelFileType.FILE_CONTENT,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
(ModelFileType.FILE_NAME,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(
_PORTRAIT_EXPECTED_BLENDSHAPES), None),
(ModelFileType.FILE_CONTENT,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(
_PORTRAIT_EXPECTED_BLENDSHAPES), None),
(ModelFileType.FILE_NAME,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(
_PORTRAIT_EXPECTED_BLENDSHAPES),
_make_expected_facial_transformation_matrixes()),
(ModelFileType.FILE_CONTENT,
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(
_PORTRAIT_EXPECTED_BLENDSHAPES),
_make_expected_facial_transformation_matrixes()))
def test_detect_in_context(
self, model_file_type, model_name, expected_face_landmarks,
expected_face_blendshapes, expected_facial_transformation_matrixes):
# Creates face landmarker.
model_path = test_utils.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceLandmarkerOptions(
base_options=base_options,
output_face_blendshapes=True if expected_face_blendshapes else False,
output_facial_transformation_matrixes=True
if expected_facial_transformation_matrixes else False)
with _FaceLandmarker.create_from_options(options) as landmarker:
# Performs face landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(detection_result.face_landmarks[0],
expected_face_landmarks)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(detection_result.face_blendshapes[0],
expected_face_blendshapes)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrix_correct(
detection_result.facial_transformation_matrixes,
expected_facial_transformation_matrixes)
def test_detect_succeeds_with_num_faces(self):
# Creates face landmarker.
model_path = test_utils.get_test_data_path(
_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
options = _FaceLandmarkerOptions(base_options=base_options, num_faces=1,
output_face_blendshapes=True)
with _FaceLandmarker.create_from_options(options) as landmarker:
# Load the portrait image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_PORTRAIT_IMAGE))
# Performs face landmarks detection on the input.
detection_result = landmarker.detect(test_image)
# Comparing results.
self.assertLen(detection_result.face_blendshapes, 1)
def test_empty_detection_outputs(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path))
with _FaceLandmarker.create_from_options(options) as landmarker:
# Load the image with no faces.
no_faces_test_image = _Image.create_from_file(
test_utils.get_test_data_path(_CAT_IMAGE))
# Performs face landmarks detection on the input.
detection_result = landmarker.detect(no_faces_test_image)
self.assertEmpty(detection_result.face_landmarks)
self.assertEmpty(detection_result.face_blendshapes)
self.assertEmpty(detection_result.facial_transformation_matrixes)
def test_missing_result_callback(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _FaceLandmarker.create_from_options(options) as unused_landmarker:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _FaceLandmarker.create_from_options(options) as unused_landmarker:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _FaceLandmarker.create_from_options(options) as landmarker:
unused_result = landmarker.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(_FACE_LANDMARKER_BUNDLE_ASSET_FILE, _get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
(_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
(_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(_PORTRAIT_EXPECTED_BLENDSHAPES), None),
(_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(_PORTRAIT_EXPECTED_BLENDSHAPES),
_make_expected_facial_transformation_matrixes()))
def test_detect_for_video(
self, model_name, expected_face_landmarks, expected_face_blendshapes,
expected_facial_transformation_matrixes):
# Creates face landmarker.
model_path = test_utils.get_test_data_path(model_name)
base_options = _BaseOptions(model_asset_path=model_path)
options = _FaceLandmarkerOptions(
base_options=base_options,
running_mode=_RUNNING_MODE.VIDEO,
output_face_blendshapes=True if expected_face_blendshapes else False,
output_facial_transformation_matrixes=True
if expected_facial_transformation_matrixes else False)
with _FaceLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
# Performs face landmarks detection on the input.
detection_result = landmarker.detect_for_video(self.test_image,
timestamp)
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(detection_result.face_landmarks[0],
expected_face_landmarks)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(detection_result.face_blendshapes[0],
expected_face_blendshapes)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrix_correct(
detection_result.facial_transformation_matrixes,
expected_facial_transformation_matrixes)
def test_calling_detect_in_live_stream_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _FaceLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _FaceLandmarker.create_from_options(options) as landmarker:
landmarker.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
(_PORTRAIT_IMAGE, _FACE_LANDMARKER_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
(_PORTRAIT_IMAGE, _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
(_PORTRAIT_IMAGE, _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(_PORTRAIT_EXPECTED_BLENDSHAPES), None),
(_PORTRAIT_IMAGE, _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
_get_expected_face_landmarks(
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
_get_expected_face_blendshapes(_PORTRAIT_EXPECTED_BLENDSHAPES),
_make_expected_facial_transformation_matrixes()))
def test_detect_async_calls(
self, image_path, model_name, expected_face_landmarks,
expected_face_blendshapes, expected_facial_transformation_matrixes):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
observed_timestamp_ms = -1
def check_result(result: FaceLandmarkerResult, output_image: _Image,
timestamp_ms: int):
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(result.face_landmarks[0],
expected_face_landmarks)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(result.face_blendshapes[0],
expected_face_blendshapes)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrix_correct(
result.facial_transformation_matrixes,
expected_facial_transformation_matrixes)
self.assertTrue(
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
model_path = test_utils.get_test_data_path(model_name)
options = _FaceLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_face_blendshapes=True if expected_face_blendshapes else False,
output_facial_transformation_matrixes=True
if expected_facial_transformation_matrixes else False,
result_callback=check_result)
with _FaceLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
landmarker.detect_async(test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -166,6 +166,7 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2",
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:matrix_data", "//mediapipe/tasks/python/components/containers:matrix_data",

View File

@ -25,6 +25,8 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2 from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2
# TODO: Remove later.
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
@ -160,15 +162,22 @@ def _build_landmarker_result(
category_name=face_blendshapes.label)) category_name=face_blendshapes.label))
face_blendshapes_results.append(face_blendshapes_categories) face_blendshapes_results.append(face_blendshapes_categories)
# Creates a dummy FaceGeometry packet to initialize the symbol database.
# TODO: Remove later.
face_geometry_in = face_geometry_pb2.FaceGeometry()
p = packet_creator.create_proto(face_geometry_in).at(100)
face_geometry_out = packet_getter.get_proto(p)
facial_transformation_matrixes_results = [] facial_transformation_matrixes_results = []
if _FACE_GEOMETRY_STREAM_NAME in output_packets: if _FACE_GEOMETRY_STREAM_NAME in output_packets:
facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
output_packets[_FACE_GEOMETRY_STREAM_NAME]) output_packets[_FACE_GEOMETRY_STREAM_NAME])
for proto in facial_transformation_matrixes_proto_list: for proto in facial_transformation_matrixes_proto_list:
matrix_data = matrix_data_pb2.MatrixData() if proto.pose_transform_matrix:
matrix_data.MergeFrom(proto) matrix_data = matrix_data_pb2.MatrixData()
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data) matrix_data.MergeFrom(proto.pose_transform_matrix)
facial_transformation_matrixes_results.append(matrix) matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
facial_transformation_matrixes_results.append(matrix)
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results, return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
facial_transformation_matrixes_results) facial_transformation_matrixes_results)