Registering FaceGeometry proto.

PiperOrigin-RevId: 516597971
This commit is contained in:
Jiuqiang Tang 2023-03-14 12:18:26 -07:00 committed by Copybara-Service
parent 854ab25ee9
commit fef8b9cb58
3 changed files with 18 additions and 0 deletions

View File

@ -57,6 +57,7 @@ pybind_extension(
"//mediapipe/framework/formats:landmark_registration", "//mediapipe/framework/formats:landmark_registration",
"//mediapipe/framework/formats:rect_registration", "//mediapipe/framework/formats:rect_registration",
"//mediapipe/modules/objectron/calculators:annotation_registration", "//mediapipe/modules/objectron/calculators:annotation_registration",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
], ],
) )

View File

@ -28,6 +28,7 @@ from mediapipe.python._framework_bindings import calculator_graph
from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image
from mediapipe.python._framework_bindings import image_frame from mediapipe.python._framework_bindings import image_frame
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
CalculatorGraph = calculator_graph.CalculatorGraph CalculatorGraph = calculator_graph.CalculatorGraph
Image = image.Image Image = image.Image
@ -177,6 +178,11 @@ class PacketTest(absltest.TestCase):
text_format.Parse('score: 0.5', detection) text_format.Parse('score: 0.5', detection)
p = packet_creator.create_proto(detection).at(100) p = packet_creator.create_proto(detection).at(100)
def test_face_geometry_proto_packet(self):
face_geometry_in = face_geometry_pb2.FaceGeometry()
p = packet_creator.create_proto(face_geometry_in).at(100)
face_geometry_out = packet_getter.get_proto(p)
def test_string_packet(self): def test_string_packet(self):
p = packet_creator.create_string('abc').at(100) p = packet_creator.create_string('abc').at(100)
self.assertEqual(packet_getter.get_str(p), 'abc') self.assertEqual(packet_getter.get_str(p), 'abc')

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type")
licenses(["notice"]) licenses(["notice"])
@ -23,6 +24,16 @@ mediapipe_proto_library(
srcs = ["environment.proto"], srcs = ["environment.proto"],
) )
mediapipe_register_type(
base_name = "face_geometry",
include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"],
types = [
"::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry",
"::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>",
],
deps = [":face_geometry_cc_proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "face_geometry_proto", name = "face_geometry_proto",
srcs = ["face_geometry.proto"], srcs = ["face_geometry.proto"],