Updated face landmarker implementation and tests
This commit is contained in:
parent
89be4c7b64
commit
efae2830f1
|
@ -94,6 +94,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||||
] + select({
|
] + select({
|
||||||
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
||||||
"//mediapipe:windows": [],
|
"//mediapipe:windows": [],
|
||||||
|
|
|
@ -152,7 +152,7 @@ class HandLandmarkerTest(parameterized.TestCase):
|
||||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)),
|
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)),
|
||||||
(ModelFileType.FILE_CONTENT,
|
(ModelFileType.FILE_CONTENT,
|
||||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)))
|
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)))
|
||||||
def test_detect(self, model_file_type, expected_result):
|
def test_detect(self, model_file_type, expected_face_landmarks):
|
||||||
# Creates face landmarker.
|
# Creates face landmarker.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
@ -164,15 +164,14 @@ class HandLandmarkerTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
options = _FaceLandmarkerOptions(base_options=base_options,
|
options = _FaceLandmarkerOptions(base_options=base_options)
|
||||||
output_face_blendshapes=True)
|
|
||||||
landmarker = _FaceLandmarker.create_from_options(options)
|
landmarker = _FaceLandmarker.create_from_options(options)
|
||||||
|
|
||||||
# Performs face landmarks detection on the input.
|
# Performs face landmarks detection on the input.
|
||||||
detection_result = landmarker.detect(self.test_image)
|
detection_result = landmarker.detect(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self._expect_landmarks_correct(detection_result.face_landmarks,
|
self._expect_landmarks_correct(detection_result.face_landmarks[0],
|
||||||
expected_result.face_landmarks)
|
expected_face_landmarks)
|
||||||
# Closes the face landmarker explicitly when the face landmarker is not used
|
# Closes the face landmarker explicitly when the face landmarker is not used
|
||||||
# in a context.
|
# in a context.
|
||||||
landmarker.close()
|
landmarker.close()
|
||||||
|
|
|
@ -132,10 +132,6 @@ def _build_landmarker_result(
|
||||||
"""Constructs a `FaceLandmarkerResult` from output packets."""
|
"""Constructs a `FaceLandmarkerResult` from output packets."""
|
||||||
face_landmarks_proto_list = packet_getter.get_proto_list(
|
face_landmarks_proto_list = packet_getter.get_proto_list(
|
||||||
output_packets[_NORM_LANDMARKS_STREAM_NAME])
|
output_packets[_NORM_LANDMARKS_STREAM_NAME])
|
||||||
face_blendshapes_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_BLENDSHAPES_STREAM_NAME])
|
|
||||||
facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_FACE_GEOMETRY_STREAM_NAME])
|
|
||||||
|
|
||||||
face_landmarks_results = []
|
face_landmarks_results = []
|
||||||
for proto in face_landmarks_proto_list:
|
for proto in face_landmarks_proto_list:
|
||||||
|
@ -143,30 +139,36 @@ def _build_landmarker_result(
|
||||||
face_landmarks.MergeFrom(proto)
|
face_landmarks.MergeFrom(proto)
|
||||||
face_landmarks_list = []
|
face_landmarks_list = []
|
||||||
for face_landmark in face_landmarks.landmark:
|
for face_landmark in face_landmarks.landmark:
|
||||||
face_landmarks.append(
|
face_landmarks_list.append(
|
||||||
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark))
|
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark))
|
||||||
face_landmarks_results.append(face_landmarks_list)
|
face_landmarks_results.append(face_landmarks_list)
|
||||||
|
|
||||||
face_blendshapes_results = []
|
face_blendshapes_results = []
|
||||||
for proto in face_blendshapes_proto_list:
|
if _BLENDSHAPES_STREAM_NAME in output_packets:
|
||||||
face_blendshapes_categories = []
|
face_blendshapes_proto_list = packet_getter.get_proto_list(
|
||||||
face_blendshapes_classifications = classification_pb2.ClassificationList()
|
output_packets[_BLENDSHAPES_STREAM_NAME])
|
||||||
face_blendshapes_classifications.MergeFrom(proto)
|
for proto in face_blendshapes_proto_list:
|
||||||
for face_blendshapes in face_blendshapes_classifications.classification:
|
face_blendshapes_categories = []
|
||||||
face_blendshapes_categories.append(
|
face_blendshapes_classifications = classification_pb2.ClassificationList()
|
||||||
category_module.Category(
|
face_blendshapes_classifications.MergeFrom(proto)
|
||||||
index=face_blendshapes.index,
|
for face_blendshapes in face_blendshapes_classifications.classification:
|
||||||
score=face_blendshapes.score,
|
face_blendshapes_categories.append(
|
||||||
display_name=face_blendshapes.display_name,
|
category_module.Category(
|
||||||
category_name=face_blendshapes.label))
|
index=face_blendshapes.index,
|
||||||
face_blendshapes_results.append(face_blendshapes_categories)
|
score=face_blendshapes.score,
|
||||||
|
display_name=face_blendshapes.display_name,
|
||||||
|
category_name=face_blendshapes.label))
|
||||||
|
face_blendshapes_results.append(face_blendshapes_categories)
|
||||||
|
|
||||||
facial_transformation_matrixes_results = []
|
facial_transformation_matrixes_results = []
|
||||||
for proto in facial_transformation_matrixes_proto_list:
|
if _FACE_GEOMETRY_STREAM_NAME in output_packets:
|
||||||
matrix_data = matrix_data_pb2.MatrixData()
|
facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
|
||||||
matrix_data.MergeFrom(proto)
|
output_packets[_FACE_GEOMETRY_STREAM_NAME])
|
||||||
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
|
for proto in facial_transformation_matrixes_proto_list:
|
||||||
facial_transformation_matrixes_results.append(matrix)
|
matrix_data = matrix_data_pb2.MatrixData()
|
||||||
|
matrix_data.MergeFrom(proto)
|
||||||
|
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
|
||||||
|
facial_transformation_matrixes_results.append(matrix)
|
||||||
|
|
||||||
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
|
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
|
||||||
facial_transformation_matrixes_results)
|
facial_transformation_matrixes_results)
|
||||||
|
@ -298,19 +300,25 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
||||||
options.result_callback(face_landmarks_result, image,
|
options.result_callback(face_landmarks_result, image,
|
||||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
|
output_streams = [
|
||||||
|
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
|
||||||
|
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||||
|
]
|
||||||
|
|
||||||
|
if options.output_face_blendshapes:
|
||||||
|
output_streams.append(
|
||||||
|
':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]))
|
||||||
|
if options.output_facial_transformation_matrixes:
|
||||||
|
output_streams.append(
|
||||||
|
':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME]))
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
input_streams=[
|
input_streams=[
|
||||||
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||||
],
|
],
|
||||||
output_streams=[
|
output_streams=output_streams,
|
||||||
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
|
|
||||||
':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]),
|
|
||||||
':'.join([
|
|
||||||
_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME
|
|
||||||
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
|
||||||
],
|
|
||||||
task_options=options)
|
task_options=options)
|
||||||
return cls(
|
return cls(
|
||||||
task_info.generate_graph_config(
|
task_info.generate_graph_config(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user