Removed SegmenterOptions dataclasses to enumerate options within ImageSegmenterOptions instead

This commit is contained in:
kinaryml 2022-10-21 13:34:30 -07:00
parent 91b60da1dc
commit 5231a0ad9f
5 changed files with 38 additions and 85 deletions

View File

@ -20,9 +20,5 @@ licenses(["notice"])
py_library( py_library(
name = "segmenter_options", name = "segmenter_options",
srcs = ["segmenter_options.py"], srcs = ["segmenter_options.py"]
deps = [
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
) )

View File

@ -13,14 +13,7 @@
# limitations under the License. # limitations under the License.
"""Segmenter options data class.""" """Segmenter options data class."""
import dataclasses
import enum import enum
from typing import Any, Optional
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
class OutputType(enum.Enum): class OutputType(enum.Enum):
@ -33,46 +26,3 @@ class Activation(enum.Enum):
NONE = 0 NONE = 0
SIGMOID = 1 SIGMOID = 1
SOFTMAX = 2 SOFTMAX = 2
@dataclasses.dataclass
class SegmenterOptions:
"""Options for segmentation processor.
Attributes:
output_type: The output mask type allows specifying the type of
post-processing to perform on the raw model results.
activation: Activation function to apply to input tensor.
"""
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
activation: Optional[Activation] = Activation.NONE
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _SegmenterOptionsProto:
"""Generates a protobuf object to pass to the C++ layer."""
return _SegmenterOptionsProto(
output_type=self.output_type.value,
activation=self.activation.value
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _SegmenterOptionsProto) -> "SegmenterOptions":
"""Creates a `SegmenterOptions` object from the given protobuf object."""
return SegmenterOptions(
output_type=OutputType(pb2_obj.output_type),
activation=Activation(pb2_obj.output_type)
)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, SegmenterOptions):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -35,7 +35,6 @@ _Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat _ImageFormat = image_frame_module.ImageFormat
_OutputType = segmenter_options.OutputType _OutputType = segmenter_options.OutputType
_Activation = segmenter_options.Activation _Activation = segmenter_options.Activation
_SegmenterOptions = segmenter_options.SegmenterOptions
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
@ -125,9 +124,8 @@ class ImageSegmenterTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options, options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options) output_type=_OutputType.CATEGORY_MASK)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
@ -153,19 +151,16 @@ class ImageSegmenterTest(parameterized.TestCase):
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode. # Run segmentation on the model in CATEGORY_MASK mode.
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options, options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options) output_type=_OutputType.CATEGORY_MASK)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
category_masks = segmenter.segment(self.test_image) category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view() category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
segmenter_options = _SegmenterOptions( options = _ImageSegmenterOptions(base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK, output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX) activation=_Activation.SOFTMAX)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image) confidence_masks = segmenter.segment(self.test_image)
@ -204,9 +199,8 @@ class ImageSegmenterTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options, options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options) output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) category_masks = segmenter.segment(self.test_image)
@ -284,10 +278,9 @@ class ImageSegmenterTest(parameterized.TestCase):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self): def test_segment_for_video(self):
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
segmenter_options=segmenter_options, output_type=_OutputType.CATEGORY_MASK,
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
@ -348,10 +341,9 @@ class ImageSegmenterTest(parameterized.TestCase):
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms self.observed_timestamp_ms = timestamp_ms
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
segmenter_options=segmenter_options, output_type=_OutputType.CATEGORY_MASK,
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) result_callback=check_result)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:

View File

@ -46,6 +46,7 @@ py_library(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
"//mediapipe/tasks/python/components/proto:segmenter_options", "//mediapipe/tasks/python/components/proto:segmenter_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",

View File

@ -19,22 +19,25 @@ from typing import Callable, List, Mapping, Optional
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
from mediapipe.tasks.python.components.proto import segmenter_options from mediapipe.tasks.python.components.proto import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions _ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
_SegmenterOptions = segmenter_options.SegmenterOptions _OutputType = segmenter_options.OutputType
_RunningMode = running_mode_module.VisionTaskRunningMode _Activation = segmenter_options.Activation
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner_module.TaskRunner _TaskRunner = task_runner.TaskRunner
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
@ -57,14 +60,17 @@ class ImageSegmenterOptions:
2) The video mode for segmenting objects on the decoded frames of a video. 2) The video mode for segmenting objects on the decoded frames of a video.
3) The live stream mode for segmenting objects on a live stream of input 3) The live stream mode for segmenting objects on a live stream of input
data, such as from camera. data, such as from camera.
segmenter_options: Options for the image segmenter task. output_type: The output mask type allows specifying the type of
post-processing to perform on the raw model results.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
segmenter_options: _SegmenterOptions = _SegmenterOptions() output_type: Optional[_OutputType] = _OutputType.CATEGORY_MASK
activation: Optional[_Activation] = _Activation.NONE
result_callback: Optional[ result_callback: Optional[
Callable[[List[image_module.Image], image_module.Image, int], Callable[[List[image_module.Image], image_module.Image, int],
None]] = None None]] = None
@ -74,8 +80,10 @@ class ImageSegmenterOptions:
"""Generates an ImageSegmenterOptions protobuf object.""" """Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
segmenter_options_proto = self.segmenter_options.to_pb2() segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value,
activation=self.activation.value
)
return _ImageSegmenterOptionsProto( return _ImageSegmenterOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto segmenter_options=segmenter_options_proto
@ -127,7 +135,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If other types of error occurred. RuntimeError: If other types of error occurred.
""" """
def packets_callback(output_packets: Mapping[str, packet_module.Packet]): def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
segmentation_result = packet_getter.get_image_list( segmentation_result = packet_getter.get_image_list(
@ -159,8 +167,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
image: MediaPipe Image. image: MediaPipe Image.
Returns: Returns:
A segmentation result object that contains a list of segmentation masks If the output_type is CATEGORY_MASK, the returned vector of images is
as images. per-category segmented image mask.
If the output_type is CONFIDENCE_MASK, the returned vector of images
contains only one confidence image mask. A segmentation result object that
contains a list of segmentation masks as images.
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
@ -186,8 +197,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
timestamp_ms: The timestamp of the input video frame in milliseconds. timestamp_ms: The timestamp of the input video frame in milliseconds.
Returns: Returns:
A segmentation result object that contains a list of segmentation masks If the output_type is CATEGORY_MASK, the returned vector of images is
as images. per-category segmented image mask.
If the output_type is CONFIDENCE_MASK, the returned vector of images
contains only one confidence image mask. A segmentation result object that
contains a list of segmentation masks as images.
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.