diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/proto/BUILD index a58f77e6c..ef37d9270 100644 --- a/mediapipe/tasks/python/components/proto/BUILD +++ b/mediapipe/tasks/python/components/proto/BUILD @@ -20,9 +20,5 @@ licenses(["notice"]) py_library( name = "segmenter_options", - srcs = ["segmenter_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], + srcs = ["segmenter_options.py"] ) diff --git a/mediapipe/tasks/python/components/proto/segmenter_options.py b/mediapipe/tasks/python/components/proto/segmenter_options.py index dcf34cc39..5f8e22777 100644 --- a/mediapipe/tasks/python/components/proto/segmenter_options.py +++ b/mediapipe/tasks/python/components/proto/segmenter_options.py @@ -13,14 +13,7 @@ # limitations under the License. """Segmenter options data class.""" -import dataclasses import enum -from typing import Any, Optional - -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions class OutputType(enum.Enum): @@ -33,46 +26,3 @@ class Activation(enum.Enum): NONE = 0 SIGMOID = 1 SOFTMAX = 2 - - -@dataclasses.dataclass -class SegmenterOptions: - """Options for segmentation processor. - Attributes: - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. - activation: Activation function to apply to input tensor. - """ - - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK - activation: Optional[Activation] = Activation.NONE - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _SegmenterOptionsProto: - """Generates a protobuf object to pass to the C++ layer.""" - return _SegmenterOptionsProto( - output_type=self.output_type.value, - activation=self.activation.value - ) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, pb2_obj: _SegmenterOptionsProto) -> "SegmenterOptions": - """Creates a `SegmenterOptions` object from the given protobuf object.""" - return SegmenterOptions( - output_type=OutputType(pb2_obj.output_type), - activation=Activation(pb2_obj.output_type) - ) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. - """ - if not isinstance(other, SegmenterOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index b53f301b1..2395eae57 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -35,7 +35,6 @@ _Image = image_module.Image _ImageFormat = image_frame_module.ImageFormat _OutputType = segmenter_options.OutputType _Activation = segmenter_options.Activation -_SegmenterOptions = segmenter_options.SegmenterOptions _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode @@ -125,9 +124,8 @@ class ImageSegmenterTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) options = _ImageSegmenterOptions(base_options=base_options, - segmenter_options=segmenter_options) + output_type=_OutputType.CATEGORY_MASK) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. @@ -153,19 +151,16 @@ class ImageSegmenterTest(parameterized.TestCase): base_options = _BaseOptions(model_asset_path=self.model_path) # Run segmentation on the model in CATEGORY_MASK mode. - segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) options = _ImageSegmenterOptions(base_options=base_options, - segmenter_options=segmenter_options) + output_type=_OutputType.CATEGORY_MASK) segmenter = _ImageSegmenter.create_from_options(options) category_masks = segmenter.segment(self.test_image) category_mask = category_masks[0].numpy_view() # Run segmentation on the model in CONFIDENCE_MASK mode. - segmenter_options = _SegmenterOptions( - output_type=_OutputType.CONFIDENCE_MASK, - activation=_Activation.SOFTMAX) options = _ImageSegmenterOptions(base_options=base_options, - segmenter_options=segmenter_options) + output_type=_OutputType.CONFIDENCE_MASK, + activation=_Activation.SOFTMAX) segmenter = _ImageSegmenter.create_from_options(options) confidence_masks = segmenter.segment(self.test_image) @@ -204,9 +199,8 @@ class ImageSegmenterTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) options = _ImageSegmenterOptions(base_options=base_options, - segmenter_options=segmenter_options) + output_type=_OutputType.CATEGORY_MASK) with _ImageSegmenter.create_from_options(options) as segmenter: # Performs image segmentation on the input. category_masks = segmenter.segment(self.test_image) @@ -284,10 +278,9 @@ class ImageSegmenterTest(parameterized.TestCase): segmenter.segment_for_video(self.test_image, 0) def test_segment_for_video(self): - segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - segmenter_options=segmenter_options, + output_type=_OutputType.CATEGORY_MASK, running_mode=_RUNNING_MODE.VIDEO) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): @@ -348,10 +341,9 @@ class ImageSegmenterTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - segmenter_options=segmenter_options, + output_type=_OutputType.CATEGORY_MASK, running_mode=_RUNNING_MODE.LIVE_STREAM, result_callback=check_result) with _ImageSegmenter.create_from_options(options) as segmenter: diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 3875ea5de..863312e4c 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,6 +46,7 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2", "//mediapipe/tasks/python/components/proto:segmenter_options", "//mediapipe/tasks/python/core:base_options", diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 51f802925..e7278eb96 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -19,22 +19,25 @@ from typing import Callable, List, Mapping, Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module -from mediapipe.python._framework_bindings import packet as packet_module -from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.python._framework_bindings import packet +from mediapipe.python._framework_bindings import task_runner +from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2 from mediapipe.tasks.python.components.proto import segmenter_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api -from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode _BaseOptions = base_options_module.BaseOptions +_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions -_SegmenterOptions = segmenter_options.SegmenterOptions -_RunningMode = running_mode_module.VisionTaskRunningMode +_OutputType = segmenter_options.OutputType +_Activation = segmenter_options.Activation +_RunningMode = vision_task_running_mode.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo -_TaskRunner = task_runner_module.TaskRunner +_TaskRunner = task_runner.TaskRunner _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' @@ -57,14 +60,17 @@ class ImageSegmenterOptions: 2) The video mode for segmenting objects on the decoded frames of a video. 3) The live stream mode for segmenting objects on a live stream of input data, such as from camera. - segmenter_options: Options for the image segmenter task. + output_type: The output mask type allows specifying the type of + post-processing to perform on the raw model results. + activation: Activation function to apply to input tensor. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - segmenter_options: _SegmenterOptions = _SegmenterOptions() + output_type: Optional[_OutputType] = _OutputType.CATEGORY_MASK + activation: Optional[_Activation] = _Activation.NONE result_callback: Optional[ Callable[[List[image_module.Image], image_module.Image, int], None]] = None @@ -74,8 +80,10 @@ class ImageSegmenterOptions: """Generates an ImageSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - segmenter_options_proto = self.segmenter_options.to_pb2() - + segmenter_options_proto = _SegmenterOptionsProto( + output_type=self.output_type.value, + activation=self.activation.value + ) return _ImageSegmenterOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto @@ -127,7 +135,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If other types of error occurred. """ - def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + def packets_callback(output_packets: Mapping[str, packet.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return segmentation_result = packet_getter.get_image_list( @@ -159,8 +167,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): image: MediaPipe Image. Returns: - A segmentation result object that contains a list of segmentation masks - as images. + If the output_type is CATEGORY_MASK, the returned vector of images is + per-category segmented image mask. + If the output_type is CONFIDENCE_MASK, the returned vector of images + contains only one confidence image mask. A segmentation result object that + contains a list of segmentation masks as images. Raises: ValueError: If any of the input arguments is invalid. @@ -186,8 +197,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): timestamp_ms: The timestamp of the input video frame in milliseconds. Returns: - A segmentation result object that contains a list of segmentation masks - as images. + If the output_type is CATEGORY_MASK, the returned vector of images is + per-category segmented image mask. + If the output_type is CONFIDENCE_MASK, the returned vector of images + contains only one confidence image mask. A segmentation result object that + contains a list of segmentation masks as images. Raises: ValueError: If any of the input arguments is invalid.