Removed MatrixData dataclass and used NumPy to represent Matrix
This commit is contained in:
parent
9aea1be6f9
commit
2753c79fde
|
@ -82,15 +82,6 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "matrix_data",
|
||||
srcs = ["matrix_data.py"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:matrix_data_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "detections",
|
||||
srcs = ["detections.py"],
|
||||
|
|
|
@ -1,81 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Matrix data data class."""
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from mediapipe.framework.formats import matrix_data_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_MatrixDataProto = matrix_data_pb2.MatrixData
|
||||
|
||||
|
||||
class Layout(enum.Enum):
|
||||
COLUMN_MAJOR = 0
|
||||
ROW_MAJOR = 1
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MatrixData:
|
||||
"""This stores the Matrix data.
|
||||
|
||||
Here the data is stored in column-major order by default.
|
||||
|
||||
Attributes:
|
||||
rows: The number of rows in the matrix.
|
||||
cols: The number of columns in the matrix.
|
||||
data: The data stored in the matrix as a NumPy array.
|
||||
layout: The order in which the data are stored. Defaults to COLUMN_MAJOR.
|
||||
"""
|
||||
|
||||
rows: int = None
|
||||
cols: int = None
|
||||
data: np.ndarray = None
|
||||
layout: Optional[Layout] = Layout.COLUMN_MAJOR
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _MatrixDataProto:
|
||||
"""Generates a MatrixData protobuf object."""
|
||||
return _MatrixDataProto(
|
||||
rows=self.rows,
|
||||
cols=self.cols,
|
||||
packed_data=self.data,
|
||||
layout=self.layout.value)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _MatrixDataProto) -> 'MatrixData':
|
||||
"""Creates a `MatrixData` object from the given protobuf object."""
|
||||
return MatrixData(
|
||||
rows=pb2_obj.rows,
|
||||
cols=pb2_obj.cols,
|
||||
data=np.array(pb2_obj.packed_data),
|
||||
layout=Layout(pb2_obj.layout))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, MatrixData):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -153,7 +153,6 @@ py_test(
|
|||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:landmark",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/containers:matrix_data",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:face_landmarker",
|
||||
|
|
|
@ -26,7 +26,6 @@ from mediapipe.framework.formats import classification_pb2
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
@ -39,7 +38,6 @@ _BaseOptions = base_options_module.BaseOptions
|
|||
_Category = category_module.Category
|
||||
_Rect = rect_module.Rect
|
||||
_Landmark = landmark_module.Landmark
|
||||
_MatrixData = matrix_data_module.MatrixData
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_Image = image_module.Image
|
||||
_FaceLandmarker = face_landmarker.FaceLandmarker
|
||||
|
@ -90,14 +88,12 @@ def _get_expected_face_blendshapes(file_path: str):
|
|||
|
||||
|
||||
def _make_expected_facial_transformation_matrixes():
|
||||
data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
|
||||
matrix = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
|
||||
[0.0072318087, 0.99744856, -0.07102106, 22.212194],
|
||||
[-0.029815676, 0.07120642, 0.9970159, -64.76358],
|
||||
[0, 0, 0, 1]])
|
||||
rows, cols = len(data), len(data[0])
|
||||
facial_transformation_matrixes_results = []
|
||||
facial_transformation_matrix = _MatrixData(rows, cols, data.flatten())
|
||||
facial_transformation_matrixes_results.append(facial_transformation_matrix)
|
||||
facial_transformation_matrixes_results.append(matrix)
|
||||
return facial_transformation_matrixes_results
|
||||
|
||||
|
||||
|
@ -145,11 +141,13 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
|||
self.assertLen(actual_matrix_list, len(expected_matrix_list))
|
||||
|
||||
for i, rename_me in enumerate(actual_matrix_list):
|
||||
self.assertEqual(rename_me.rows, expected_matrix_list[i].rows)
|
||||
self.assertEqual(rename_me.cols, expected_matrix_list[i].cols)
|
||||
self.assertEqual(rename_me.shape[0],
|
||||
expected_matrix_list[i].shape[0])
|
||||
self.assertEqual(rename_me.shape[1],
|
||||
expected_matrix_list[i].shape[1])
|
||||
self.assertAlmostEqual(
|
||||
rename_me.data.all(),
|
||||
expected_matrix_list[i].data.all(),
|
||||
rename_me.all(),
|
||||
expected_matrix_list[i].all(),
|
||||
delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
|
|
|
@ -189,7 +189,6 @@ py_library(
|
|||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:landmark",
|
||||
"//mediapipe/tasks/python/components/containers:matrix_data",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
|
|
|
@ -17,6 +17,7 @@ import dataclasses
|
|||
import enum
|
||||
from typing import Callable, Mapping, Optional, List
|
||||
|
||||
import numpy as np
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.framework.formats import landmark_pb2
|
||||
from mediapipe.framework.formats import matrix_data_pb2
|
||||
|
@ -29,7 +30,6 @@ from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_grap
|
|||
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -39,6 +39,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
|||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_FaceLandmarkerGraphOptionsProto = face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
|
||||
_LayoutEnum = matrix_data_pb2.MatrixData.Layout
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
@ -126,7 +127,7 @@ class FaceLandmarkerResult:
|
|||
|
||||
face_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||
face_blendshapes: List[List[category_module.Category]]
|
||||
facial_transformation_matrixes: List[matrix_data_module.MatrixData]
|
||||
facial_transformation_matrixes: List[np.ndarray]
|
||||
|
||||
|
||||
def _build_landmarker_result(
|
||||
|
@ -170,7 +171,9 @@ def _build_landmarker_result(
|
|||
if proto.pose_transform_matrix:
|
||||
matrix_data = matrix_data_pb2.MatrixData()
|
||||
matrix_data.MergeFrom(proto.pose_transform_matrix)
|
||||
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
|
||||
order = 'C' if matrix_data.layout == _LayoutEnum.ROW_MAJOR else 'F'
|
||||
data = np.array(matrix_data.packed_data, order=order)
|
||||
matrix = data.reshape((matrix_data.rows, matrix_data.cols))
|
||||
facial_transformation_matrixes_results.append(matrix)
|
||||
|
||||
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,
|
||||
|
|
Loading…
Reference in New Issue
Block a user