Removed MatrixData dataclass and used NumPy to represent Matrix
This commit is contained in:
parent
9aea1be6f9
commit
2753c79fde
|
@ -82,15 +82,6 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "matrix_data",
|
|
||||||
srcs = ["matrix_data.py"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework/formats:matrix_data_py_pb2",
|
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "detections",
|
name = "detections",
|
||||||
srcs = ["detections.py"],
|
srcs = ["detections.py"],
|
||||||
|
|
|
@ -1,81 +0,0 @@
|
||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Matrix data data class."""
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import enum
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from mediapipe.framework.formats import matrix_data_pb2
|
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|
||||||
|
|
||||||
_MatrixDataProto = matrix_data_pb2.MatrixData
|
|
||||||
|
|
||||||
|
|
||||||
class Layout(enum.Enum):
|
|
||||||
COLUMN_MAJOR = 0
|
|
||||||
ROW_MAJOR = 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class MatrixData:
|
|
||||||
"""This stores the Matrix data.
|
|
||||||
|
|
||||||
Here the data is stored in column-major order by default.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
rows: The number of rows in the matrix.
|
|
||||||
cols: The number of columns in the matrix.
|
|
||||||
data: The data stored in the matrix as a NumPy array.
|
|
||||||
layout: The order in which the data are stored. Defaults to COLUMN_MAJOR.
|
|
||||||
"""
|
|
||||||
|
|
||||||
rows: int = None
|
|
||||||
cols: int = None
|
|
||||||
data: np.ndarray = None
|
|
||||||
layout: Optional[Layout] = Layout.COLUMN_MAJOR
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _MatrixDataProto:
|
|
||||||
"""Generates a MatrixData protobuf object."""
|
|
||||||
return _MatrixDataProto(
|
|
||||||
rows=self.rows,
|
|
||||||
cols=self.cols,
|
|
||||||
packed_data=self.data,
|
|
||||||
layout=self.layout.value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(cls, pb2_obj: _MatrixDataProto) -> 'MatrixData':
|
|
||||||
"""Creates a `MatrixData` object from the given protobuf object."""
|
|
||||||
return MatrixData(
|
|
||||||
rows=pb2_obj.rows,
|
|
||||||
cols=pb2_obj.cols,
|
|
||||||
data=np.array(pb2_obj.packed_data),
|
|
||||||
layout=Layout(pb2_obj.layout))
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, MatrixData):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
|
@ -153,7 +153,6 @@ py_test(
|
||||||
"//mediapipe/tasks/python/components/containers:category",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:landmark",
|
"//mediapipe/tasks/python/components/containers:landmark",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/components/containers:matrix_data",
|
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
"//mediapipe/tasks/python/vision:face_landmarker",
|
"//mediapipe/tasks/python/vision:face_landmarker",
|
||||||
|
|
|
@ -26,7 +26,6 @@ from mediapipe.framework.formats import classification_pb2
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||||
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
|
|
||||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
@ -39,7 +38,6 @@ _BaseOptions = base_options_module.BaseOptions
|
||||||
_Category = category_module.Category
|
_Category = category_module.Category
|
||||||
_Rect = rect_module.Rect
|
_Rect = rect_module.Rect
|
||||||
_Landmark = landmark_module.Landmark
|
_Landmark = landmark_module.Landmark
|
||||||
_MatrixData = matrix_data_module.MatrixData
|
|
||||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_FaceLandmarker = face_landmarker.FaceLandmarker
|
_FaceLandmarker = face_landmarker.FaceLandmarker
|
||||||
|
@ -90,14 +88,12 @@ def _get_expected_face_blendshapes(file_path: str):
|
||||||
|
|
||||||
|
|
||||||
def _make_expected_facial_transformation_matrixes():
|
def _make_expected_facial_transformation_matrixes():
|
||||||
data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
|
matrix = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
|
||||||
[0.0072318087, 0.99744856, -0.07102106, 22.212194],
|
[0.0072318087, 0.99744856, -0.07102106, 22.212194],
|
||||||
[-0.029815676, 0.07120642, 0.9970159, -64.76358],
|
[-0.029815676, 0.07120642, 0.9970159, -64.76358],
|
||||||
[0, 0, 0, 1]])
|
[0, 0, 0, 1]])
|
||||||
rows, cols = len(data), len(data[0])
|
|
||||||
facial_transformation_matrixes_results = []
|
facial_transformation_matrixes_results = []
|
||||||
facial_transformation_matrix = _MatrixData(rows, cols, data.flatten())
|
facial_transformation_matrixes_results.append(matrix)
|
||||||
facial_transformation_matrixes_results.append(facial_transformation_matrix)
|
|
||||||
return facial_transformation_matrixes_results
|
return facial_transformation_matrixes_results
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,11 +141,13 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
self.assertLen(actual_matrix_list, len(expected_matrix_list))
|
self.assertLen(actual_matrix_list, len(expected_matrix_list))
|
||||||
|
|
||||||
for i, rename_me in enumerate(actual_matrix_list):
|
for i, rename_me in enumerate(actual_matrix_list):
|
||||||
self.assertEqual(rename_me.rows, expected_matrix_list[i].rows)
|
self.assertEqual(rename_me.shape[0],
|
||||||
self.assertEqual(rename_me.cols, expected_matrix_list[i].cols)
|
expected_matrix_list[i].shape[0])
|
||||||
|
self.assertEqual(rename_me.shape[1],
|
||||||
|
expected_matrix_list[i].shape[1])
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
rename_me.data.all(),
|
rename_me.all(),
|
||||||
expected_matrix_list[i].data.all(),
|
expected_matrix_list[i].all(),
|
||||||
delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN)
|
delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN)
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
|
|
@ -189,7 +189,6 @@ py_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:category",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:landmark",
|
"//mediapipe/tasks/python/components/containers:landmark",
|
||||||
"//mediapipe/tasks/python/components/containers:matrix_data",
|
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
"//mediapipe/tasks/python/core:task_info",
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
|
|
@ -17,6 +17,7 @@ import dataclasses
|
||||||
import enum
|
import enum
|
||||||
from typing import Callable, Mapping, Optional, List
|
from typing import Callable, Mapping, Optional, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from mediapipe.framework.formats import classification_pb2
|
from mediapipe.framework.formats import classification_pb2
|
||||||
from mediapipe.framework.formats import landmark_pb2
|
from mediapipe.framework.formats import landmark_pb2
|
||||||
from mediapipe.framework.formats import matrix_data_pb2
|
from mediapipe.framework.formats import matrix_data_pb2
|
||||||
|
@ -29,7 +30,6 @@ from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_grap
|
||||||
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
|
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||||
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
|
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
@ -39,6 +39,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
||||||
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_FaceLandmarkerGraphOptionsProto = face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
|
_FaceLandmarkerGraphOptionsProto = face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
|
||||||
|
_LayoutEnum = matrix_data_pb2.MatrixData.Layout
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
@ -126,7 +127,7 @@ class FaceLandmarkerResult:
|
||||||
|
|
||||||
face_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
face_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||||
face_blendshapes: List[List[category_module.Category]]
|
face_blendshapes: List[List[category_module.Category]]
|
||||||
facial_transformation_matrixes: List[matrix_data_module.MatrixData]
|
facial_transformation_matrixes: List[np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
def _build_landmarker_result(
|
def _build_landmarker_result(
|
||||||
|
@ -170,7 +171,9 @@ def _build_landmarker_result(
|
||||||
if proto.pose_transform_matrix:
|
if proto.pose_transform_matrix:
|
||||||
matrix_data = matrix_data_pb2.MatrixData()
|
matrix_data = matrix_data_pb2.MatrixData()
|
||||||
matrix_data.MergeFrom(proto.pose_transform_matrix)
|
matrix_data.MergeFrom(proto.pose_transform_matrix)
|
||||||
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
|
order = 'C' if matrix_data.layout == _LayoutEnum.ROW_MAJOR else 'F'
|
||||||
|
data = np.array(matrix_data.packed_data, order=order)
|
||||||
|
matrix = data.reshape((matrix_data.rows, matrix_data.cols))
|
||||||
facial_transformation_matrixes_results.append(matrix)
|
facial_transformation_matrixes_results.append(matrix)
|
||||||
|
|
||||||
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
|
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user