Updated Face Stylizer implementation and tests

This commit is contained in:
kinaryml 2023-03-22 21:15:04 -07:00
parent 3afe4cafc4
commit da70497f35
5 changed files with 10 additions and 9 deletions

View File

@ -36,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL = 'face_stylizer_model_placeholder.tflite'
_MODEL = 'face_stylization_dummy.tflite'
_IMAGE = 'cats_and_dogs.jpg'
_STYLIZED_IMAGE = 'stylized_image_placeholder.jpg'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
@ -104,11 +104,9 @@ class FaceStylizerTest(parameterized.TestCase):
stylizer = _FaceStylizer.create_from_options(options)
# Performs face stylization on the input.
stylized_image = stylizer.detect(self.test_image)
stylized_image = stylizer.stylize(self.test_image)
# Comparing results.
self.assertTrue(
np.array_equal(stylized_image.numpy_view(),
self.test_image.numpy_view()))
# TODO:
# Closes the stylizer explicitly when the stylizer is not used in
# a context.
stylizer.close()

View File

@ -162,7 +162,7 @@ py_library(
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:face_stylizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@ -124,8 +124,10 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME]
stylized_image = packet_getter.get_image(stylized_image_packet)
options.result_callback(
stylized_image_packet, image,
stylized_image, image,
stylized_image_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
@ -173,7 +175,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
})
return output_packets[_STYLIZED_IMAGE_NAME]
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
def stylize_for_video(
self,
@ -209,7 +211,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
return output_packets[_STYLIZED_IMAGE_NAME]
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
def stylize_async(
self,

View File

@ -170,6 +170,7 @@ filegroup(
"face_detection_short_range.tflite",
"face_landmark.tflite",
"face_landmark_with_attention.tflite",
"face_stylization_dummy.tflite",
"face_landmarker.task",
"hair_segmentation.tflite",
"hand_landmark_full.tflite",

Binary file not shown.