Removed MatrixData dataclass and used NumPy to represent Matrix

This commit is contained in:
kinaryml 2023-03-16 11:50:07 -07:00
parent 9aea1be6f9
commit 2753c79fde
6 changed files with 65 additions and 156 deletions

View File

@ -82,15 +82,6 @@ py_library(
],
)
py_library(
name = "matrix_data",
srcs = ["matrix_data.py"],
deps = [
"//mediapipe/framework/formats:matrix_data_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "detections",
srcs = ["detections.py"],

View File

@ -1,81 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matrix data data class."""
import dataclasses
import enum
from typing import Any, Optional
import numpy as np
from mediapipe.framework.formats import matrix_data_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_MatrixDataProto = matrix_data_pb2.MatrixData
class Layout(enum.Enum):
COLUMN_MAJOR = 0
ROW_MAJOR = 1
@dataclasses.dataclass
class MatrixData:
"""This stores the Matrix data.
Here the data is stored in column-major order by default.
Attributes:
rows: The number of rows in the matrix.
cols: The number of columns in the matrix.
data: The data stored in the matrix as a NumPy array.
layout: The order in which the data are stored. Defaults to COLUMN_MAJOR.
"""
rows: int = None
cols: int = None
data: np.ndarray = None
layout: Optional[Layout] = Layout.COLUMN_MAJOR
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _MatrixDataProto:
"""Generates a MatrixData protobuf object."""
return _MatrixDataProto(
rows=self.rows,
cols=self.cols,
packed_data=self.data,
layout=self.layout.value)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _MatrixDataProto) -> 'MatrixData':
"""Creates a `MatrixData` object from the given protobuf object."""
return MatrixData(
rows=pb2_obj.rows,
cols=pb2_obj.cols,
data=np.array(pb2_obj.packed_data),
layout=Layout(pb2_obj.layout))
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, MatrixData):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -153,7 +153,6 @@ py_test(
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/containers:matrix_data",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:face_landmarker",

View File

@ -26,7 +26,6 @@ from mediapipe.framework.formats import classification_pb2
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
@ -39,7 +38,6 @@ _BaseOptions = base_options_module.BaseOptions
_Category = category_module.Category
_Rect = rect_module.Rect
_Landmark = landmark_module.Landmark
_MatrixData = matrix_data_module.MatrixData
_NormalizedLandmark = landmark_module.NormalizedLandmark
_Image = image_module.Image
_FaceLandmarker = face_landmarker.FaceLandmarker
@ -90,14 +88,12 @@ def _get_expected_face_blendshapes(file_path: str):
def _make_expected_facial_transformation_matrixes():
data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
matrix = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546],
[0.0072318087, 0.99744856, -0.07102106, 22.212194],
[-0.029815676, 0.07120642, 0.9970159, -64.76358],
[0, 0, 0, 1]])
rows, cols = len(data), len(data[0])
facial_transformation_matrixes_results = []
facial_transformation_matrix = _MatrixData(rows, cols, data.flatten())
facial_transformation_matrixes_results.append(facial_transformation_matrix)
facial_transformation_matrixes_results.append(matrix)
return facial_transformation_matrixes_results
@ -145,11 +141,13 @@ class FaceLandmarkerTest(parameterized.TestCase):
self.assertLen(actual_matrix_list, len(expected_matrix_list))
for i, rename_me in enumerate(actual_matrix_list):
self.assertEqual(rename_me.rows, expected_matrix_list[i].rows)
self.assertEqual(rename_me.cols, expected_matrix_list[i].cols)
self.assertEqual(rename_me.shape[0],
expected_matrix_list[i].shape[0])
self.assertEqual(rename_me.shape[1],
expected_matrix_list[i].shape[1])
self.assertAlmostEqual(
rename_me.data.all(),
expected_matrix_list[i].data.all(),
rename_me.all(),
expected_matrix_list[i].all(),
delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN)
def test_create_from_file_succeeds_with_valid_model_path(self):

View File

@ -189,7 +189,6 @@ py_library(
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:matrix_data",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@ -17,6 +17,7 @@ import dataclasses
import enum
from typing import Callable, Mapping, Optional, List
import numpy as np
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.framework.formats import matrix_data_pb2
@ -29,7 +30,6 @@ from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_grap
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -39,6 +39,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
_BaseOptions = base_options_module.BaseOptions
_FaceLandmarkerGraphOptionsProto = face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
_LayoutEnum = matrix_data_pb2.MatrixData.Layout
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -126,7 +127,7 @@ class FaceLandmarkerResult:
face_landmarks: List[List[landmark_module.NormalizedLandmark]]
face_blendshapes: List[List[category_module.Category]]
facial_transformation_matrixes: List[matrix_data_module.MatrixData]
facial_transformation_matrixes: List[np.ndarray]
def _build_landmarker_result(
@ -170,7 +171,9 @@ def _build_landmarker_result(
if proto.pose_transform_matrix:
matrix_data = matrix_data_pb2.MatrixData()
matrix_data.MergeFrom(proto.pose_transform_matrix)
matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data)
order = 'C' if matrix_data.layout == _LayoutEnum.ROW_MAJOR else 'F'
data = np.array(matrix_data.packed_data, order=order)
matrix = data.reshape((matrix_data.rows, matrix_data.cols))
facial_transformation_matrixes_results.append(matrix)
return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,