Updated face landmarker implementation and tests

This commit is contained in:
kinaryml 2023-03-13 08:46:41 -07:00
parent 89be4c7b64
commit efae2830f1
3 changed files with 42 additions and 34 deletions

View File

@@ -94,6 +94,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
] + select({ ] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows. # TODO: Build text_classifier_graph and text_embedder_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],

View File

@@ -152,7 +152,7 @@ class HandLandmarkerTest(parameterized.TestCase):
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)), _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, (ModelFileType.FILE_CONTENT,
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS))) _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)))
def test_detect(self, model_file_type, expected_result): def test_detect(self, model_file_type, expected_face_landmarks):
# Creates face landmarker. # Creates face landmarker.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
@@ -164,15 +164,14 @@ class HandLandmarkerTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _FaceLandmarkerOptions(base_options=base_options, options = _FaceLandmarkerOptions(base_options=base_options)
output_face_blendshapes=True)
landmarker = _FaceLandmarker.create_from_options(options) landmarker = _FaceLandmarker.create_from_options(options)
# Performs face landmarks detection on the input. # Performs face landmarks detection on the input.
detection_result = landmarker.detect(self.test_image) detection_result = landmarker.detect(self.test_image)
# Comparing results. # Comparing results.
self._expect_landmarks_correct(detection_result.face_landmarks, self._expect_landmarks_correct(detection_result.face_landmarks[0],
expected_result.face_landmarks) expected_face_landmarks)
# Closes the face landmarker explicitly when the face landmarker is not used # Closes the face landmarker explicitly when the face landmarker is not used
# in a context. # in a context.
landmarker.close() landmarker.close()

View File

@@ -132,10 +132,6 @@ def _build_landmarker_result(
"""Constructs a `FaceLandmarkerResult` from output packets.""" """Constructs a `FaceLandmarkerResult` from output packets."""
face_landmarks_proto_list = packet_getter.get_proto_list( face_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_NORM_LANDMARKS_STREAM_NAME]) output_packets[_NORM_LANDMARKS_STREAM_NAME])
face_blendshapes_proto_list = packet_getter.get_proto_list(
output_packets[_BLENDSHAPES_STREAM_NAME])
facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
output_packets[_FACE_GEOMETRY_STREAM_NAME])
face_landmarks_results = [] face_landmarks_results = []
for proto in face_landmarks_proto_list: for proto in face_landmarks_proto_list:
@@ -143,11 +139,14 @@ def _build_landmarker_result(
face_landmarks.MergeFrom(proto) face_landmarks.MergeFrom(proto)
face_landmarks_list = [] face_landmarks_list = []
for face_landmark in face_landmarks.landmark: for face_landmark in face_landmarks.landmark:
face_landmarks.append( face_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)) landmark_module.NormalizedLandmark.create_from_pb2(face_landmark))
face_landmarks_results.append(face_landmarks_list) face_landmarks_results.append(face_landmarks_list)
face_blendshapes_results = [] face_blendshapes_results = []
if _BLENDSHAPES_STREAM_NAME in output_packets:
face_blendshapes_proto_list = packet_getter.get_proto_list(
output_packets[_BLENDSHAPES_STREAM_NAME])
for proto in face_blendshapes_proto_list: for proto in face_blendshapes_proto_list:
face_blendshapes_categories = [] face_blendshapes_categories = []
face_blendshapes_classifications = classification_pb2.ClassificationList() face_blendshapes_classifications = classification_pb2.ClassificationList()
@@ -162,6 +161,9 @@ def _build_landmarker_result(
face_blendshapes_results.append(face_blendshapes_categories) face_blendshapes_results.append(face_blendshapes_categories)
facial_transformation_matrixes_results = [] facial_transformation_matrixes_results = []
if _FACE_GEOMETRY_STREAM_NAME in output_packets:
facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
output_packets[_FACE_GEOMETRY_STREAM_NAME])
for proto in facial_transformation_matrixes_proto_list: for proto in facial_transformation_matrixes_proto_list:
matrix_data = matrix_data_pb2.MatrixData() matrix_data = matrix_data_pb2.MatrixData()
matrix_data.MergeFrom(proto) matrix_data.MergeFrom(proto)
@@ -298,19 +300,25 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
options.result_callback(face_landmarks_result, image, options.result_callback(face_landmarks_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
output_streams = [
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
]
if options.output_face_blendshapes:
output_streams.append(
':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]))
if options.output_facial_transformation_matrixes:
output_streams.append(
':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME]))
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]),
':'.join([
_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options) task_options=options)
return cls( return cls(
task_info.generate_graph_config( task_info.generate_graph_config(