Set the default running model to Image for face stylizer.

PiperOrigin-RevId: 564511316
This commit is contained in:
MediaPipe Team 2023-09-11 15:00:26 -07:00 committed by Copybara-Service
parent 81481df304
commit 7a04d60134

View File

@ -25,6 +25,7 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = ( _FaceStylizerGraphOptionsProto = (
@ -115,7 +116,10 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
], ],
task_options=options, task_options=options,
) )
return cls(task_info.generate_graph_config()) return cls(
task_info.generate_graph_config(),
running_mode=running_mode_module.VisionTaskRunningMode.IMAGE,
)
def stylize( def stylize(
self, self,
@ -141,7 +145,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face stylization failed to run. RuntimeError: If face stylization failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image) image_processing_options, image
)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto( _NORM_RECT_STREAM_NAME: packet_creator.create_proto(