Revised face landmarker implementation and tests
This commit is contained in:
		
							parent
							
								
									4a7489cd3a
								
							
						
					
					
						commit
						23681cde0d
					
				| 
						 | 
				
			
			@ -17,6 +17,7 @@ import dataclasses
 | 
			
		|||
import enum
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from mediapipe.framework.formats import matrix_data_pb2
 | 
			
		||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -32,7 +33,7 @@ class MatrixData:
 | 
			
		|||
  Attributes:
 | 
			
		||||
    rows: The number of rows in the matrix.
 | 
			
		||||
    cols: The number of columns in the matrix.
 | 
			
		||||
    data: The data stored in the matrix.
 | 
			
		||||
    data: The data stored in the matrix as a NumPy array.
 | 
			
		||||
    layout: The order in which the data are stored. Defaults to COLUMN_MAJOR.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -40,10 +41,10 @@ class MatrixData:
 | 
			
		|||
    COLUMN_MAJOR = 0
 | 
			
		||||
    ROW_MAJOR = 1
 | 
			
		||||
 | 
			
		||||
  rows: Optional[int] = None
 | 
			
		||||
  cols: Optional[int] = None
 | 
			
		||||
  data: Optional[float] = None
 | 
			
		||||
  layout: Optional[Layout] = None
 | 
			
		||||
  rows: int = None
 | 
			
		||||
  cols: int = None
 | 
			
		||||
  data: np.ndarray = None
 | 
			
		||||
  layout: Optional[Layout] = Layout.COLUMN_MAJOR
 | 
			
		||||
 | 
			
		||||
  @doc_controls.do_not_generate_docs
 | 
			
		||||
  def to_pb2(self) -> _MatrixDataProto:
 | 
			
		||||
| 
						 | 
				
			
			@ -51,7 +52,7 @@ class MatrixData:
 | 
			
		|||
    return _MatrixDataProto(
 | 
			
		||||
        rows=self.rows,
 | 
			
		||||
        cols=self.cols,
 | 
			
		||||
        data=self.data,
 | 
			
		||||
        data=self.data.tolist(),
 | 
			
		||||
        layout=self.layout)
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -61,7 +62,7 @@ class MatrixData:
 | 
			
		|||
    return MatrixData(
 | 
			
		||||
        rows=pb2_obj.rows,
 | 
			
		||||
        cols=pb2_obj.cols,
 | 
			
		||||
        data=pb2_obj.data,
 | 
			
		||||
        data=np.array(pb2_obj.data),
 | 
			
		||||
        layout=pb2_obj.layout)
 | 
			
		||||
 | 
			
		||||
  def __eq__(self, other: Any) -> bool:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -126,10 +126,10 @@ py_test(
 | 
			
		|||
    deps = [
 | 
			
		||||
        "//mediapipe/python:_framework_bindings",
 | 
			
		||||
        "//mediapipe/framework/formats:landmark_py_pb2",
 | 
			
		||||
        "//mediapipe/framework/formats:classification_py_pb2",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:category",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:landmark",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:rect",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:classification_result",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:matrix_data",
 | 
			
		||||
        "//mediapipe/tasks/python/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,11 +22,12 @@ import numpy as np
 | 
			
		|||
 | 
			
		||||
from google.protobuf import text_format
 | 
			
		||||
from mediapipe.framework.formats import landmark_pb2
 | 
			
		||||
from mediapipe.framework.formats import classification_pb2
 | 
			
		||||
from mediapipe.python._framework_bindings import image as image_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import category as category_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import rect as rect_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
 | 
			
		||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
from mediapipe.tasks.python.vision import face_landmarker
 | 
			
		||||
| 
						 | 
				
			
			@ -38,6 +39,7 @@ _BaseOptions = base_options_module.BaseOptions
 | 
			
		|||
_Category = category_module.Category
 | 
			
		||||
_Rect = rect_module.Rect
 | 
			
		||||
_Landmark = landmark_module.Landmark
 | 
			
		||||
_MatrixData = matrix_data_module.MatrixData
 | 
			
		||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
 | 
			
		||||
_Image = image_module.Image
 | 
			
		||||
_FaceLandmarker = face_landmarker.FaceLandmarker
 | 
			
		||||
| 
						 | 
				
			
			@ -51,6 +53,7 @@ _PORTRAIT_IMAGE = 'portrait.jpg'
 | 
			
		|||
_PORTRAIT_EXPECTED_FACE_LANDMARKS = 'portrait_expected_face_landmarks.pbtxt'
 | 
			
		||||
_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION = 'portrait_expected_face_landmarks_with_attention.pbtxt'
 | 
			
		||||
_PORTRAIT_EXPECTED_BLENDSHAPES = 'portrait_expected_blendshapes_with_attention.pbtxt'
 | 
			
		||||
_PORTRAIT_EXPECTED_FACE_GEOMETRY = 'portrait_expected_face_geometry_with_attention.pbtxt'
 | 
			
		||||
_LANDMARKS_DIFF_MARGIN = 0.03
 | 
			
		||||
_BLENDSHAPES_DIFF_MARGIN = 0.1
 | 
			
		||||
_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02
 | 
			
		||||
| 
						 | 
				
			
			@ -61,10 +64,40 @@ def _get_expected_face_landmarks(file_path: str):
 | 
			
		|||
  with open(proto_file_path, 'rb') as f:
 | 
			
		||||
    proto = landmark_pb2.NormalizedLandmarkList()
 | 
			
		||||
    text_format.Parse(f.read(), proto)
 | 
			
		||||
    landmarks = []
 | 
			
		||||
    face_landmarks = []
 | 
			
		||||
    for landmark in proto.landmark:
 | 
			
		||||
      landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
 | 
			
		||||
  return landmarks
 | 
			
		||||
      face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
 | 
			
		||||
  return face_landmarks
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_expected_face_blendshapes(file_path: str):
 | 
			
		||||
  proto_file_path = test_utils.get_test_data_path(file_path)
 | 
			
		||||
  with open(proto_file_path, 'rb') as f:
 | 
			
		||||
    proto = classification_pb2.ClassificationList()
 | 
			
		||||
    text_format.Parse(f.read(), proto)
 | 
			
		||||
    face_blendshapes_categories = []
 | 
			
		||||
    face_blendshapes_classifications = classification_pb2.ClassificationList()
 | 
			
		||||
    face_blendshapes_classifications.MergeFrom(proto)
 | 
			
		||||
    for face_blendshapes in face_blendshapes_classifications.classification:
 | 
			
		||||
      face_blendshapes_categories.append(
 | 
			
		||||
        category_module.Category(
 | 
			
		||||
          index=face_blendshapes.index,
 | 
			
		||||
          score=face_blendshapes.score,
 | 
			
		||||
          display_name=face_blendshapes.display_name,
 | 
			
		||||
          category_name=face_blendshapes.label))
 | 
			
		||||
  return face_blendshapes_categories
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_expected_facial_transformation_matrixes():
 | 
			
		||||
  data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
 | 
			
		||||
          [0.0072318087, 0.99744856, -0.07102106, 22.212194],
 | 
			
		||||
          [-0.029815676, 0.07120642, 0.9970159, -64.76358],
 | 
			
		||||
          [0, 0, 0, 1]])
 | 
			
		||||
  rows, cols = len(data), len(data[0])
 | 
			
		||||
  facial_transformation_matrixes_results = []
 | 
			
		||||
  facial_transformation_matrix = _MatrixData(rows, cols, data)
 | 
			
		||||
  facial_transformation_matrixes_results.append(facial_transformation_matrix)
 | 
			
		||||
  return facial_transformation_matrixes_results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelFileType(enum.Enum):
 | 
			
		||||
| 
						 | 
				
			
			@ -148,30 +181,82 @@ class HandLandmarkerTest(parameterized.TestCase):
 | 
			
		|||
      self.assertIsInstance(landmarker, _FaceLandmarker)
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (ModelFileType.FILE_NAME, _FACE_LANDMARKER_BUNDLE_ASSET_FILE,
 | 
			
		||||
      _get_expected_face_landmarks(
 | 
			
		||||
        _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, _FACE_LANDMARKER_BUNDLE_ASSET_FILE,
 | 
			
		||||
       _get_expected_face_landmarks(
 | 
			
		||||
        _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None),
 | 
			
		||||
      (ModelFileType.FILE_NAME,
 | 
			
		||||
       _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)),
 | 
			
		||||
       _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
       _get_expected_face_landmarks(
 | 
			
		||||
         _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT,
 | 
			
		||||
       _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS)))
 | 
			
		||||
  def test_detect(self, model_file_type, expected_face_landmarks):
 | 
			
		||||
       _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
       _get_expected_face_landmarks(
 | 
			
		||||
         _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None),
 | 
			
		||||
      (ModelFileType.FILE_NAME,
 | 
			
		||||
       _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
       _get_expected_face_landmarks(
 | 
			
		||||
         _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
 | 
			
		||||
       _get_expected_face_blendshapes(
 | 
			
		||||
         _PORTRAIT_EXPECTED_BLENDSHAPES), None),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT,
 | 
			
		||||
       _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
       _get_expected_face_landmarks(
 | 
			
		||||
         _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
 | 
			
		||||
       _get_expected_face_blendshapes(
 | 
			
		||||
         _PORTRAIT_EXPECTED_BLENDSHAPES), None),
 | 
			
		||||
      # (ModelFileType.FILE_NAME,
 | 
			
		||||
      #  _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
      #  _get_expected_face_landmarks(
 | 
			
		||||
      #    _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
 | 
			
		||||
      #  _get_expected_face_blendshapes(
 | 
			
		||||
      #    _PORTRAIT_EXPECTED_BLENDSHAPES),
 | 
			
		||||
      #  _make_expected_facial_transformation_matrixes()),
 | 
			
		||||
      # (ModelFileType.FILE_CONTENT,
 | 
			
		||||
      #  _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE,
 | 
			
		||||
      #  _get_expected_face_landmarks(
 | 
			
		||||
      #    _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION),
 | 
			
		||||
      #  _get_expected_face_blendshapes(
 | 
			
		||||
      #    _PORTRAIT_EXPECTED_BLENDSHAPES),
 | 
			
		||||
      #  _make_expected_facial_transformation_matrixes())
 | 
			
		||||
  )
 | 
			
		||||
  def test_detect(self, model_file_type, model_name, expected_face_landmarks,
 | 
			
		||||
                  expected_face_blendshapes, expected_facial_transformation_matrix):
 | 
			
		||||
    # Creates face landmarker.
 | 
			
		||||
    model_path = test_utils.get_test_data_path(model_name)
 | 
			
		||||
    if model_file_type is ModelFileType.FILE_NAME:
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path=model_path)
 | 
			
		||||
    elif model_file_type is ModelFileType.FILE_CONTENT:
 | 
			
		||||
      with open(self.model_path, 'rb') as f:
 | 
			
		||||
      with open(model_path, 'rb') as f:
 | 
			
		||||
        model_content = f.read()
 | 
			
		||||
      base_options = _BaseOptions(model_asset_buffer=model_content)
 | 
			
		||||
    else:
 | 
			
		||||
      # Should never happen
 | 
			
		||||
      raise ValueError('model_file_type is invalid.')
 | 
			
		||||
 | 
			
		||||
    options = _FaceLandmarkerOptions(base_options=base_options)
 | 
			
		||||
    options = _FaceLandmarkerOptions(
 | 
			
		||||
        base_options=base_options,
 | 
			
		||||
        output_face_blendshapes=True if expected_face_blendshapes else False,
 | 
			
		||||
        output_facial_transformation_matrixes=True
 | 
			
		||||
        if expected_facial_transformation_matrix else False)
 | 
			
		||||
    landmarker = _FaceLandmarker.create_from_options(options)
 | 
			
		||||
 | 
			
		||||
    # Performs face landmarks detection on the input.
 | 
			
		||||
    detection_result = landmarker.detect(self.test_image)
 | 
			
		||||
    # Comparing results.
 | 
			
		||||
    self._expect_landmarks_correct(detection_result.face_landmarks[0],
 | 
			
		||||
                                   expected_face_landmarks)
 | 
			
		||||
    if expected_face_landmarks is not None:
 | 
			
		||||
      self._expect_landmarks_correct(detection_result.face_landmarks[0],
 | 
			
		||||
                                     expected_face_landmarks)
 | 
			
		||||
    if expected_face_blendshapes is not None:
 | 
			
		||||
      self._expect_blendshapes_correct(detection_result.face_blendshapes[0],
 | 
			
		||||
                                       expected_face_blendshapes)
 | 
			
		||||
    if expected_facial_transformation_matrix is not None:
 | 
			
		||||
      self._expect_facial_transformation_matrix_correct(
 | 
			
		||||
          detection_result.facial_transformation_matrixes[0],
 | 
			
		||||
          expected_facial_transformation_matrix)
 | 
			
		||||
 | 
			
		||||
    # Closes the face landmarker explicitly when the face landmarker is not used
 | 
			
		||||
    # in a context.
 | 
			
		||||
    landmarker.close()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -162,6 +162,7 @@ def _build_landmarker_result(
 | 
			
		|||
 | 
			
		||||
  facial_transformation_matrixes_results = []
 | 
			
		||||
  if _FACE_GEOMETRY_STREAM_NAME in output_packets:
 | 
			
		||||
    print(output_packets[_FACE_GEOMETRY_STREAM_NAME])
 | 
			
		||||
    facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
 | 
			
		||||
      output_packets[_FACE_GEOMETRY_STREAM_NAME])
 | 
			
		||||
    for proto in facial_transformation_matrixes_proto_list:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -156,6 +156,7 @@ filegroup(
 | 
			
		|||
        "face_landmark.tflite",
 | 
			
		||||
        "face_landmark_with_attention.tflite",
 | 
			
		||||
        "face_landmarker.task",
 | 
			
		||||
        "face_landmarker_with_blendshapes.task",
 | 
			
		||||
        "hair_segmentation.tflite",
 | 
			
		||||
        "hand_landmark_full.tflite",
 | 
			
		||||
        "hand_landmark_lite.tflite",
 | 
			
		||||
| 
						 | 
				
			
			@ -191,6 +192,7 @@ filegroup(
 | 
			
		|||
        "pointing_up_landmarks.pbtxt",
 | 
			
		||||
        "pointing_up_rotated_landmarks.pbtxt",
 | 
			
		||||
        "portrait_expected_detection.pbtxt",
 | 
			
		||||
        "portrait_expected_blendshapes_with_attention.pbtxt",
 | 
			
		||||
        "portrait_expected_face_geometry_with_attention.pbtxt",
 | 
			
		||||
        "portrait_expected_face_landmarks.pbtxt",
 | 
			
		||||
        "portrait_expected_face_landmarks_with_attention.pbtxt",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user