Code cleanup

This commit is contained in:
kinaryml 2022-09-21 04:06:12 -07:00
parent 500ad5a7f0
commit 660a88b7ea
2 changed files with 7 additions and 43 deletions

View File

@ -17,7 +17,6 @@ import enum
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components import segmenter_options from mediapipe.tasks.python.components import segmenter_options
@ -30,6 +29,7 @@ _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_OutputType = segmenter_options.OutputType _OutputType = segmenter_options.OutputType
_Activation = segmenter_options.Activation _Activation = segmenter_options.Activation
_SegmenterOptions = segmenter_options.SegmenterOptions
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
@ -99,17 +99,14 @@ class ImageSegmenterTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options, options = _ImageSegmenterOptions(base_options=base_options,
output_type=_OutputType.CATEGORY_MASK) segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) category_masks = segmenter.segment(self.test_image)
# Comparing results.
print(len(category_masks))
s
# Closes the segmenter explicitly when the segmenter is not used in # Closes the segmenter explicitly when the segmenter is not used in
# a context. # a context.
segmenter.close() segmenter.close()

View File

@ -57,16 +57,14 @@ class ImageSegmenterOptions:
2) The video mode for detecting objects on the decoded frames of a video. 2) The video mode for detecting objects on the decoded frames of a video.
3) The live stream mode for detecting objects on a live stream of input 3) The live stream mode for detecting objects on a live stream of input
data, such as from camera. data, such as from camera.
output_type: Optional output mask type. segmenter_options: Options for the image segmenter task.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[segmenter_options.OutputType] = segmenter_options.OutputType.CATEGORY_MASK segmenter_options: _SegmenterOptions = _SegmenterOptions()
activation: Optional[segmenter_options.Activation] = segmenter_options.Activation.NONE
result_callback: Optional[ result_callback: Optional[
Callable[[List[image_module.Image], image_module.Image, int], Callable[[List[image_module.Image], image_module.Image, int],
None]] = None None]] = None
@ -76,15 +74,11 @@ class ImageSegmenterOptions:
"""Generates an ImageSegmenterOptions protobuf object.""" """Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
segmenter_options_proto = self.segmenter_options.to_pb2()
segmenter_options = _SegmenterOptions(
output_type=self.output_type,
activation=self.activation
)
return _ImageSegmenterOptionsProto( return _ImageSegmenterOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options.to_pb2() segmenter_options=segmenter_options_proto
) )
@ -176,30 +170,3 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
segmentation_result = packet_getter.get_proto_list( segmentation_result = packet_getter.get_proto_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]) output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result return segmentation_result
# def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None:
# """Sends live image data (an Image with a unique timestamp) to perform image segmentation.
#
# This method will return immediately after the input image is accepted. The
# results will be available via the `result_callback` provided in the
# `ImageSegmenterOptions`. The `segment_async` method is designed to process
# live stream data such as camera input. To lower the overall latency, image
# segmenter may drop the input images if needed. In other words, it's not
# guaranteed to have output per input image. The `result_callback` provides:
# - A segmentation result object that contains a list of segmentation masks
# as images.
# - The input image that the image segmenter runs on.
# - The input timestamp in milliseconds.
#
# Args:
# image: MediaPipe Image.
# timestamp_ms: The timestamp of the input image in milliseconds.
#
# Raises:
# ValueError: If the current input timestamp is smaller than what the object
# detector has already processed.
# """
# self._send_live_stream_data({
# _IMAGE_IN_STREAM_NAME:
# packet_creator.create_image(image).at(timestamp_ms)
# })