Updated Face Stylizer implementation and tests
This commit is contained in:
parent
3afe4cafc4
commit
da70497f35
|
@ -36,7 +36,7 @@ _FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
||||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
|
|
||||||
_MODEL = 'face_stylizer_model_placeholder.tflite'
|
_MODEL = 'face_stylization_dummy.tflite'
|
||||||
_IMAGE = 'cats_and_dogs.jpg'
|
_IMAGE = 'cats_and_dogs.jpg'
|
||||||
_STYLIZED_IMAGE = 'stylized_image_placeholder.jpg'
|
_STYLIZED_IMAGE = 'stylized_image_placeholder.jpg'
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
@ -104,11 +104,9 @@ class FaceStylizerTest(parameterized.TestCase):
|
||||||
stylizer = _FaceStylizer.create_from_options(options)
|
stylizer = _FaceStylizer.create_from_options(options)
|
||||||
|
|
||||||
# Performs face stylization on the input.
|
# Performs face stylization on the input.
|
||||||
stylized_image = stylizer.detect(self.test_image)
|
stylized_image = stylizer.stylize(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self.assertTrue(
|
# TODO:
|
||||||
np.array_equal(stylized_image.numpy_view(),
|
|
||||||
self.test_image.numpy_view()))
|
|
||||||
# Closes the stylizer explicitly when the stylizer is not used in
|
# Closes the stylizer explicitly when the stylizer is not used in
|
||||||
# a context.
|
# a context.
|
||||||
stylizer.close()
|
stylizer.close()
|
||||||
|
|
|
@ -162,7 +162,7 @@ py_library(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:face_stylizer_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
"//mediapipe/tasks/python/core:task_info",
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
|
|
@ -124,8 +124,10 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
return
|
return
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME]
|
stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME]
|
||||||
|
stylized_image = packet_getter.get_image(stylized_image_packet)
|
||||||
|
|
||||||
options.result_callback(
|
options.result_callback(
|
||||||
stylized_image_packet, image,
|
stylized_image, image,
|
||||||
stylized_image_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
stylized_image_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
|
@ -173,7 +175,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
_NORM_RECT_STREAM_NAME:
|
_NORM_RECT_STREAM_NAME:
|
||||||
packet_creator.create_proto(normalized_rect.to_pb2())
|
packet_creator.create_proto(normalized_rect.to_pb2())
|
||||||
})
|
})
|
||||||
return output_packets[_STYLIZED_IMAGE_NAME]
|
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
|
||||||
|
|
||||||
def stylize_for_video(
|
def stylize_for_video(
|
||||||
self,
|
self,
|
||||||
|
@ -209,7 +211,7 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
return output_packets[_STYLIZED_IMAGE_NAME]
|
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
|
||||||
|
|
||||||
def stylize_async(
|
def stylize_async(
|
||||||
self,
|
self,
|
||||||
|
|
1
mediapipe/tasks/testdata/vision/BUILD
vendored
1
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -170,6 +170,7 @@ filegroup(
|
||||||
"face_detection_short_range.tflite",
|
"face_detection_short_range.tflite",
|
||||||
"face_landmark.tflite",
|
"face_landmark.tflite",
|
||||||
"face_landmark_with_attention.tflite",
|
"face_landmark_with_attention.tflite",
|
||||||
|
"face_stylization_dummy.tflite",
|
||||||
"face_landmarker.task",
|
"face_landmarker.task",
|
||||||
"hair_segmentation.tflite",
|
"hair_segmentation.tflite",
|
||||||
"hand_landmark_full.tflite",
|
"hand_landmark_full.tflite",
|
||||||
|
|
BIN
mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite
vendored
Normal file
BIN
mediapipe/tasks/testdata/vision/face_stylization_dummy.tflite
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user