Removed SegmenterOptions dataclasses to enumerate options within ImageSegmenterOptions instead
This commit is contained in:
parent
91b60da1dc
commit
5231a0ad9f
|
@ -20,9 +20,5 @@ licenses(["notice"])
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "segmenter_options",
|
name = "segmenter_options",
|
||||||
srcs = ["segmenter_options.py"],
|
srcs = ["segmenter_options.py"]
|
||||||
deps = [
|
|
||||||
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
|
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,14 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Segmenter options data class."""
|
"""Segmenter options data class."""
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import enum
|
import enum
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
|
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|
||||||
|
|
||||||
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
|
||||||
|
|
||||||
|
|
||||||
class OutputType(enum.Enum):
|
class OutputType(enum.Enum):
|
||||||
|
@ -33,46 +26,3 @@ class Activation(enum.Enum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
SIGMOID = 1
|
SIGMOID = 1
|
||||||
SOFTMAX = 2
|
SOFTMAX = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class SegmenterOptions:
|
|
||||||
"""Options for segmentation processor.
|
|
||||||
Attributes:
|
|
||||||
output_type: The output mask type allows specifying the type of
|
|
||||||
post-processing to perform on the raw model results.
|
|
||||||
activation: Activation function to apply to input tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
|
|
||||||
activation: Optional[Activation] = Activation.NONE
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _SegmenterOptionsProto:
|
|
||||||
"""Generates a protobuf object to pass to the C++ layer."""
|
|
||||||
return _SegmenterOptionsProto(
|
|
||||||
output_type=self.output_type.value,
|
|
||||||
activation=self.activation.value
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls, pb2_obj: _SegmenterOptionsProto) -> "SegmenterOptions":
|
|
||||||
"""Creates a `SegmenterOptions` object from the given protobuf object."""
|
|
||||||
return SegmenterOptions(
|
|
||||||
output_type=OutputType(pb2_obj.output_type),
|
|
||||||
activation=Activation(pb2_obj.output_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, SegmenterOptions):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ _Image = image_module.Image
|
||||||
_ImageFormat = image_frame_module.ImageFormat
|
_ImageFormat = image_frame_module.ImageFormat
|
||||||
_OutputType = segmenter_options.OutputType
|
_OutputType = segmenter_options.OutputType
|
||||||
_Activation = segmenter_options.Activation
|
_Activation = segmenter_options.Activation
|
||||||
_SegmenterOptions = segmenter_options.SegmenterOptions
|
|
||||||
_ImageSegmenter = image_segmenter.ImageSegmenter
|
_ImageSegmenter = image_segmenter.ImageSegmenter
|
||||||
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
||||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
|
@ -125,9 +124,8 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
options = _ImageSegmenterOptions(base_options=base_options,
|
options = _ImageSegmenterOptions(base_options=base_options,
|
||||||
segmenter_options=segmenter_options)
|
output_type=_OutputType.CATEGORY_MASK)
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
segmenter = _ImageSegmenter.create_from_options(options)
|
||||||
|
|
||||||
# Performs image segmentation on the input.
|
# Performs image segmentation on the input.
|
||||||
|
@ -153,19 +151,16 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
|
||||||
# Run segmentation on the model in CATEGORY_MASK mode.
|
# Run segmentation on the model in CATEGORY_MASK mode.
|
||||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
options = _ImageSegmenterOptions(base_options=base_options,
|
options = _ImageSegmenterOptions(base_options=base_options,
|
||||||
segmenter_options=segmenter_options)
|
output_type=_OutputType.CATEGORY_MASK)
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
segmenter = _ImageSegmenter.create_from_options(options)
|
||||||
category_masks = segmenter.segment(self.test_image)
|
category_masks = segmenter.segment(self.test_image)
|
||||||
category_mask = category_masks[0].numpy_view()
|
category_mask = category_masks[0].numpy_view()
|
||||||
|
|
||||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||||
segmenter_options = _SegmenterOptions(
|
options = _ImageSegmenterOptions(base_options=base_options,
|
||||||
output_type=_OutputType.CONFIDENCE_MASK,
|
output_type=_OutputType.CONFIDENCE_MASK,
|
||||||
activation=_Activation.SOFTMAX)
|
activation=_Activation.SOFTMAX)
|
||||||
options = _ImageSegmenterOptions(base_options=base_options,
|
|
||||||
segmenter_options=segmenter_options)
|
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
segmenter = _ImageSegmenter.create_from_options(options)
|
||||||
confidence_masks = segmenter.segment(self.test_image)
|
confidence_masks = segmenter.segment(self.test_image)
|
||||||
|
|
||||||
|
@ -204,9 +199,8 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
options = _ImageSegmenterOptions(base_options=base_options,
|
options = _ImageSegmenterOptions(base_options=base_options,
|
||||||
segmenter_options=segmenter_options)
|
output_type=_OutputType.CATEGORY_MASK)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
# Performs image segmentation on the input.
|
# Performs image segmentation on the input.
|
||||||
category_masks = segmenter.segment(self.test_image)
|
category_masks = segmenter.segment(self.test_image)
|
||||||
|
@ -284,10 +278,9 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
segmenter.segment_for_video(self.test_image, 0)
|
segmenter.segment_for_video(self.test_image, 0)
|
||||||
|
|
||||||
def test_segment_for_video(self):
|
def test_segment_for_video(self):
|
||||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
segmenter_options=segmenter_options,
|
output_type=_OutputType.CATEGORY_MASK,
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
|
@ -348,10 +341,9 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
self.observed_timestamp_ms = timestamp_ms
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
segmenter_options=segmenter_options,
|
output_type=_OutputType.CATEGORY_MASK,
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=check_result)
|
result_callback=check_result)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
|
|
|
@ -46,6 +46,7 @@ py_library(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/proto:segmenter_options",
|
"//mediapipe/tasks/python/components/proto:segmenter_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
|
|
@ -19,22 +19,25 @@ from typing import Callable, List, Mapping, Optional
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner
|
||||||
|
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
|
||||||
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
|
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
|
||||||
from mediapipe.tasks.python.components.proto import segmenter_options
|
from mediapipe.tasks.python.components.proto import segmenter_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
||||||
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
|
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
|
||||||
_SegmenterOptions = segmenter_options.SegmenterOptions
|
_OutputType = segmenter_options.OutputType
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_Activation = segmenter_options.Activation
|
||||||
|
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
_TaskRunner = task_runner.TaskRunner
|
||||||
|
|
||||||
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
|
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
|
||||||
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
||||||
|
@ -57,14 +60,17 @@ class ImageSegmenterOptions:
|
||||||
2) The video mode for segmenting objects on the decoded frames of a video.
|
2) The video mode for segmenting objects on the decoded frames of a video.
|
||||||
3) The live stream mode for segmenting objects on a live stream of input
|
3) The live stream mode for segmenting objects on a live stream of input
|
||||||
data, such as from camera.
|
data, such as from camera.
|
||||||
segmenter_options: Options for the image segmenter task.
|
output_type: The output mask type allows specifying the type of
|
||||||
|
post-processing to perform on the raw model results.
|
||||||
|
activation: Activation function to apply to input tensor.
|
||||||
result_callback: The user-defined result callback for processing live stream
|
result_callback: The user-defined result callback for processing live stream
|
||||||
data. The result callback should only be specified when the running mode
|
data. The result callback should only be specified when the running mode
|
||||||
is set to the live stream mode.
|
is set to the live stream mode.
|
||||||
"""
|
"""
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
segmenter_options: _SegmenterOptions = _SegmenterOptions()
|
output_type: Optional[_OutputType] = _OutputType.CATEGORY_MASK
|
||||||
|
activation: Optional[_Activation] = _Activation.NONE
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[List[image_module.Image], image_module.Image, int],
|
Callable[[List[image_module.Image], image_module.Image, int],
|
||||||
None]] = None
|
None]] = None
|
||||||
|
@ -74,8 +80,10 @@ class ImageSegmenterOptions:
|
||||||
"""Generates an ImageSegmenterOptions protobuf object."""
|
"""Generates an ImageSegmenterOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||||
segmenter_options_proto = self.segmenter_options.to_pb2()
|
segmenter_options_proto = _SegmenterOptionsProto(
|
||||||
|
output_type=self.output_type.value,
|
||||||
|
activation=self.activation.value
|
||||||
|
)
|
||||||
return _ImageSegmenterOptionsProto(
|
return _ImageSegmenterOptionsProto(
|
||||||
base_options=base_options_proto,
|
base_options=base_options_proto,
|
||||||
segmenter_options=segmenter_options_proto
|
segmenter_options=segmenter_options_proto
|
||||||
|
@ -127,7 +135,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
RuntimeError: If other types of error occurred.
|
RuntimeError: If other types of error occurred.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
def packets_callback(output_packets: Mapping[str, packet.Packet]):
|
||||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
return
|
return
|
||||||
segmentation_result = packet_getter.get_image_list(
|
segmentation_result = packet_getter.get_image_list(
|
||||||
|
@ -159,8 +167,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A segmentation result object that contains a list of segmentation masks
|
If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
as images.
|
per-category segmented image mask.
|
||||||
|
If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||||
|
contains only one confidence image mask. A segmentation result object that
|
||||||
|
contains a list of segmentation masks as images.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
@ -186,8 +197,11 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A segmentation result object that contains a list of segmentation masks
|
If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
as images.
|
per-category segmented image mask.
|
||||||
|
If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||||
|
contains only one confidence image mask. A segmentation result object that
|
||||||
|
contains a list of segmentation masks as images.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user