diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index f56e5b3d4..e6c5723f3 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -57,6 +57,7 @@ pybind_extension( "//mediapipe/framework/formats:landmark_registration", "//mediapipe/framework/formats:rect_registration", "//mediapipe/modules/objectron/calculators:annotation_registration", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration", ], ) diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index 16fc37c87..93c8601bb 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -28,6 +28,7 @@ from mediapipe.python._framework_bindings import calculator_graph from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2 CalculatorGraph = calculator_graph.CalculatorGraph Image = image.Image @@ -177,6 +178,11 @@ class PacketTest(absltest.TestCase): text_format.Parse('score: 0.5', detection) p = packet_creator.create_proto(detection).at(100) + def test_face_geometry_proto_packet(self): + face_geometry_in = face_geometry_pb2.FaceGeometry() + p = packet_creator.create_proto(face_geometry_in).at(100) + face_geometry_out = packet_getter.get_proto(p) + def test_string_packet(self): p = packet_creator.create_string('abc').at(100) self.assertEqual(packet_getter.get_str(p), 'abc') diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD b/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD index c9dd15845..e337a3452 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") licenses(["notice"]) @@ -23,6 +24,16 @@ mediapipe_proto_library( srcs = ["environment.proto"], ) +mediapipe_register_type( + base_name = "face_geometry", + include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"], + types = [ + "::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry", + "::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>", + ], + deps = [":face_geometry_cc_proto"], +) + mediapipe_proto_library( name = "face_geometry_proto", srcs = ["face_geometry.proto"],