diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 07c31dc0c..b84ab744d 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -82,15 +82,6 @@ py_library( ], ) -py_library( - name = "matrix_data", - srcs = ["matrix_data.py"], - deps = [ - "//mediapipe/framework/formats:matrix_data_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) - py_library( name = "detections", srcs = ["detections.py"], diff --git a/mediapipe/tasks/python/components/containers/matrix_data.py b/mediapipe/tasks/python/components/containers/matrix_data.py deleted file mode 100644 index ded3a9b4f..000000000 --- a/mediapipe/tasks/python/components/containers/matrix_data.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Matrix data data class.""" - -import dataclasses -import enum -from typing import Any, Optional - -import numpy as np -from mediapipe.framework.formats import matrix_data_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_MatrixDataProto = matrix_data_pb2.MatrixData - - -class Layout(enum.Enum): - COLUMN_MAJOR = 0 - ROW_MAJOR = 1 - - -@dataclasses.dataclass -class MatrixData: - """This stores the Matrix data. - - Here the data is stored in column-major order by default. - - Attributes: - rows: The number of rows in the matrix. - cols: The number of columns in the matrix. - data: The data stored in the matrix as a NumPy array. - layout: The order in which the data are stored. Defaults to COLUMN_MAJOR. - """ - - rows: int = None - cols: int = None - data: np.ndarray = None - layout: Optional[Layout] = Layout.COLUMN_MAJOR - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _MatrixDataProto: - """Generates a MatrixData protobuf object.""" - return _MatrixDataProto( - rows=self.rows, - cols=self.cols, - packed_data=self.data, - layout=self.layout.value) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _MatrixDataProto) -> 'MatrixData': - """Creates a `MatrixData` object from the given protobuf object.""" - return MatrixData( - rows=pb2_obj.rows, - cols=pb2_obj.cols, - data=np.array(pb2_obj.packed_data), - layout=Layout(pb2_obj.layout)) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, MatrixData): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index fcff54d83..978dc1277 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -153,7 +153,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/containers:matrix_data", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:face_landmarker", diff --git a/mediapipe/tasks/python/test/vision/face_landmarker_test.py b/mediapipe/tasks/python/test/vision/face_landmarker_test.py index a6b6e02f6..34d1e0b00 100644 --- a/mediapipe/tasks/python/test/vision/face_landmarker_test.py +++ b/mediapipe/tasks/python/test/vision/face_landmarker_test.py @@ -26,7 +26,6 @@ from mediapipe.framework.formats import classification_pb2 from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module -from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -39,7 +38,6 @@ _BaseOptions = base_options_module.BaseOptions _Category = category_module.Category _Rect = rect_module.Rect _Landmark = landmark_module.Landmark -_MatrixData = matrix_data_module.MatrixData _NormalizedLandmark = landmark_module.NormalizedLandmark _Image = image_module.Image _FaceLandmarker = face_landmarker.FaceLandmarker @@ -90,14 +88,12 @@ def _get_expected_face_blendshapes(file_path: str): def _make_expected_facial_transformation_matrixes(): - data = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546], + matrix = np.array([[0.9995292, -0.005092691, 0.030254554, -0.37340546], [0.0072318087, 0.99744856, -0.07102106, 22.212194], [-0.029815676, 0.07120642, 0.9970159, -64.76358], [0, 0, 0, 1]]) - rows, cols = len(data), len(data[0]) facial_transformation_matrixes_results = [] - facial_transformation_matrix = _MatrixData(rows, cols, data.flatten()) - facial_transformation_matrixes_results.append(facial_transformation_matrix) + facial_transformation_matrixes_results.append(matrix) return facial_transformation_matrixes_results @@ -111,9 +107,9 @@ class FaceLandmarkerTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_utils.get_test_data_path(_PORTRAIT_IMAGE)) + test_utils.get_test_data_path(_PORTRAIT_IMAGE)) self.model_path = test_utils.get_test_data_path( - _FACE_LANDMARKER_BUNDLE_ASSET_FILE) + _FACE_LANDMARKER_BUNDLE_ASSET_FILE) def _expect_landmarks_correct(self, actual_landmarks, expected_landmarks): # Expects to have the same number of faces detected. @@ -145,11 +141,13 @@ class FaceLandmarkerTest(parameterized.TestCase): self.assertLen(actual_matrix_list, len(expected_matrix_list)) for i, rename_me in enumerate(actual_matrix_list): - self.assertEqual(rename_me.rows, expected_matrix_list[i].rows) - self.assertEqual(rename_me.cols, expected_matrix_list[i].cols) + self.assertEqual(rename_me.shape[0], + expected_matrix_list[i].shape[0]) + self.assertEqual(rename_me.shape[1], + expected_matrix_list[i].shape[1]) self.assertAlmostEqual( - rename_me.data.all(), - expected_matrix_list[i].data.all(), + rename_me.all(), + expected_matrix_list[i].all(), delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN) def test_create_from_file_succeeds_with_valid_model_path(self): @@ -169,7 +167,7 @@ class FaceLandmarkerTest(parameterized.TestCase): with self.assertRaisesRegex( RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') + model_asset_path='/path/to/invalid/model.tflite') options = _FaceLandmarkerOptions(base_options=base_options) _FaceLandmarker.create_from_options(options) @@ -182,46 +180,46 @@ class FaceLandmarkerTest(parameterized.TestCase): self.assertIsInstance(landmarker, _FaceLandmarker) @parameterized.parameters( - (ModelFileType.FILE_NAME, _FACE_LANDMARKER_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), - (ModelFileType.FILE_CONTENT, _FACE_LANDMARKER_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), - (ModelFileType.FILE_NAME, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None), - (ModelFileType.FILE_CONTENT, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None), - (ModelFileType.FILE_NAME, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), - _get_expected_face_blendshapes( - _PORTRAIT_EXPECTED_BLENDSHAPES), None), - (ModelFileType.FILE_CONTENT, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), - _get_expected_face_blendshapes( - _PORTRAIT_EXPECTED_BLENDSHAPES), None), - (ModelFileType.FILE_NAME, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), - _get_expected_face_blendshapes( - _PORTRAIT_EXPECTED_BLENDSHAPES), - _make_expected_facial_transformation_matrixes()), - (ModelFileType.FILE_CONTENT, - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, - _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), - _get_expected_face_blendshapes( - _PORTRAIT_EXPECTED_BLENDSHAPES), - _make_expected_facial_transformation_matrixes())) + (ModelFileType.FILE_NAME, _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), + (ModelFileType.FILE_CONTENT, _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), + (ModelFileType.FILE_NAME, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None), + (ModelFileType.FILE_CONTENT, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None), + (ModelFileType.FILE_NAME, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), + _get_expected_face_blendshapes( + _PORTRAIT_EXPECTED_BLENDSHAPES), None), + (ModelFileType.FILE_CONTENT, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), + _get_expected_face_blendshapes( + _PORTRAIT_EXPECTED_BLENDSHAPES), None), + (ModelFileType.FILE_NAME, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), + _get_expected_face_blendshapes( + _PORTRAIT_EXPECTED_BLENDSHAPES), + _make_expected_facial_transformation_matrixes()), + (ModelFileType.FILE_CONTENT, + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks( + _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), + _get_expected_face_blendshapes( + _PORTRAIT_EXPECTED_BLENDSHAPES), + _make_expected_facial_transformation_matrixes())) def test_detect( self, model_file_type, model_name, expected_face_landmarks, expected_face_blendshapes, expected_facial_transformation_matrixes): @@ -238,10 +236,10 @@ class FaceLandmarkerTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _FaceLandmarkerOptions( - base_options=base_options, - output_face_blendshapes=True if expected_face_blendshapes else False, - output_facial_transformation_matrixes=True - if expected_facial_transformation_matrixes else False) + base_options=base_options, + output_face_blendshapes=True if expected_face_blendshapes else False, + output_facial_transformation_matrixes=True + if expected_facial_transformation_matrixes else False) landmarker = _FaceLandmarker.create_from_options(options) # Performs face landmarks detection on the input. @@ -255,8 +253,8 @@ class FaceLandmarkerTest(parameterized.TestCase): expected_face_blendshapes) if expected_facial_transformation_matrixes is not None: self._expect_facial_transformation_matrix_correct( - detection_result.facial_transformation_matrixes, - expected_facial_transformation_matrixes) + detection_result.facial_transformation_matrixes, + expected_facial_transformation_matrixes) # Closes the face landmarker explicitly when the face landmarker is not used # in a context. @@ -342,7 +340,7 @@ class FaceLandmarkerTest(parameterized.TestCase): def test_detect_succeeds_with_num_faces(self): # Creates face landmarker. model_path = test_utils.get_test_data_path( - _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE) + _FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE) base_options = _BaseOptions(model_asset_path=model_path) options = _FaceLandmarkerOptions(base_options=base_options, num_faces=1, output_face_blendshapes=True) @@ -436,7 +434,7 @@ class FaceLandmarkerTest(parameterized.TestCase): @parameterized.parameters( (_FACE_LANDMARKER_BUNDLE_ASSET_FILE, _get_expected_face_landmarks( - _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), + _PORTRAIT_EXPECTED_FACE_LANDMARKS), None, None), (_FACE_LANDMARKER_WITH_BLENDSHAPES_BUNDLE_ASSET_FILE, _get_expected_face_landmarks( _PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION), None, None), diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 83763c1ae..ae02e2775 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -189,7 +189,6 @@ py_library( "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", - "//mediapipe/tasks/python/components/containers:matrix_data", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py index 6862818ce..7d53b8208 100644 --- a/mediapipe/tasks/python/vision/face_landmarker.py +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -17,6 +17,7 @@ import dataclasses import enum from typing import Callable, Mapping, Optional, List +import numpy as np from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import landmark_pb2 from mediapipe.framework.formats import matrix_data_pb2 @@ -29,7 +30,6 @@ from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_grap from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module -from mediapipe.tasks.python.components.containers import matrix_data as matrix_data_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -39,6 +39,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _BaseOptions = base_options_module.BaseOptions _FaceLandmarkerGraphOptionsProto = face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions +_LayoutEnum = matrix_data_pb2.MatrixData.Layout _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -126,7 +127,7 @@ class FaceLandmarkerResult: face_landmarks: List[List[landmark_module.NormalizedLandmark]] face_blendshapes: List[List[category_module.Category]] - facial_transformation_matrixes: List[matrix_data_module.MatrixData] + facial_transformation_matrixes: List[np.ndarray] def _build_landmarker_result( @@ -170,7 +171,9 @@ def _build_landmarker_result( if proto.pose_transform_matrix: matrix_data = matrix_data_pb2.MatrixData() matrix_data.MergeFrom(proto.pose_transform_matrix) - matrix = matrix_data_module.MatrixData.create_from_pb2(matrix_data) + order = 'C' if matrix_data.layout == _LayoutEnum.ROW_MAJOR else 'F' + data = np.array(matrix_data.packed_data, order=order) + matrix = data.reshape((matrix_data.rows, matrix_data.cols)) facial_transformation_matrixes_results.append(matrix) return FaceLandmarkerResult(face_landmarks_results, face_blendshapes_results,