Deprecated output_type for the ImageSegmenter and InteractiveSegmenter APIs

This commit is contained in:
kinaryml 2023-04-12 14:37:16 -07:00
parent c7aecb42ff
commit 3f68f90238
6 changed files with 272 additions and 132 deletions

View File

@ -93,6 +93,25 @@ py_test(
], ],
) )
py_test(
name = "interactive_segmenter_test",
srcs = ["interactive_segmenter_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:keypoint",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:interactive_segmenter",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_test( py_test(
name = "face_detector_test", name = "face_detector_test",
srcs = ["face_detector_test.py"], srcs = ["face_detector_test.py"],

View File

@ -30,10 +30,10 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.ImageSegmenterOptions.Activation _Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
@ -42,11 +42,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'deeplabv3.tflite' _MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg' _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_CAT_IMAGE = 'cat.jpg'
_CAT_MASK = 'cat_mask.jpg'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98 _MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _calculate_soft_iou(m1, m2):
intersection_sum = np.sum(m1 * m2)
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
if union_sum > 0:
return intersection_sum / union_sum
else:
return 0
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
actual_mask = actual_mask.numpy_view()
expected_mask = expected_mask.numpy_view() / 255.0
return (
actual_mask.shape == expected_mask.shape
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
)
def _similar_to_uint8_mask(actual_mask, expected_mask): def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten() actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten() expected_mask_pixels = expected_mask.numpy_view().flatten()
@ -84,6 +106,14 @@ class ImageSegmenterTest(parameterized.TestCase):
self.model_path = test_utils.get_test_data_path( self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def _load_segmentation_mask(self, file_path: str):
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
cv2.IMREAD_GRAYSCALE,
)
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter: with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter:
@ -127,20 +157,19 @@ class ImageSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) base_options=base_options, output_category_mask=True)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) segmentation_result = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8) self.assertEqual(result_pixels.dtype, np.uint8)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
@ -152,67 +181,33 @@ class ImageSegmenterTest(parameterized.TestCase):
# Creates segmenter. # Creates segmenter.
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode. # Load the cat image.
options = _ImageSegmenterOptions( test_image = _Image.create_from_file(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) test_utils.get_test_data_path(
segmenter = _ImageSegmenter.create_from_options(options) os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)))
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX) activation=_Activation.SOFTMAX)
segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
@parameterized.parameters((ModelFileType.FILE_NAME),
(ModelFileType.FILE_CONTENT))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input. segmentation_result = segmenter.segment(test_image)
category_masks = segmenter.segment(self.test_image) confidence_masks = segmentation_result.confidence_masks
self.assertLen(category_masks, 1)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_float_mask(
f'Number of pixels in the candidate mask differing from that of the ' confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') )
)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
@ -280,20 +275,49 @@ class ImageSegmenterTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self): def test_segment_for_video_in_category_mask_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp) segmentation_result = segmenter.segment_for_video(
self.assertLen(category_masks, 1) self.test_image, timestamp)
category_mask = segmentation_result.category_mask
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_segment_for_video_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)))
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(
test_image, timestamp)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
def test_calling_segment_in_live_stream_mode(self): def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
@ -325,13 +349,13 @@ class ImageSegmenterTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self): def test_segment_async_calls_in_category_mask_mode(self):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: List[image_module.Image], output_image: _Image, def check_result(result: ImageSegmenterResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
# Get the output category mask. # Get the output category mask.
category_mask = result[0] category_mask = result.category_mask
self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width) self.assertEqual(output_image.width, self.test_seg_image.width)
@ -345,13 +369,49 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) result_callback=check_result)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp) segmenter.segment_async(self.test_image, timestamp)
def test_segment_async_calls_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)))
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
observed_timestamp_ms = -1
def check_result(result: ImageSegmenterResult, output_image: _Image,
timestamp_ms: int):
# Get the output category mask.
confidence_masks = result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision import interactive_segmenter
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint _NormalizedKeypoint = keypoint_module.NormalizedKeypoint
_Rect = rect.Rect _Rect = rect.Rect
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
_RegionOfInterest = interactive_segmenter.RegionOfInterest _RegionOfInterest = interactive_segmenter.RegionOfInterest
@ -200,15 +200,14 @@ class InteractiveSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK base_options=base_options, output_category_mask=True
) )
segmenter = _InteractiveSegmenter.create_from_options(options) segmenter = _InteractiveSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
category_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
@ -219,7 +218,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask( _similar_to_uint8_mask(
category_masks[0], test_seg_image, similarity_threshold category_mask, test_seg_image, similarity_threshold
), ),
( (
'Number of pixels in the candidate mask differing from that of the' 'Number of pixels in the candidate mask differing from that of the'
@ -253,13 +252,12 @@ class InteractiveSegmenterTest(parameterized.TestCase):
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(base_options=base_options)
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
)
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
confidence_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -286,16 +284,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
) )
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(base_options=base_options)
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
)
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
confidence_masks = segmenter.segment( segmentation_result = segmenter.segment(
self.test_image, roi, image_processing_options self.test_image, roi, image_processing_options
) )
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -313,9 +310,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
) )
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(base_options=base_options)
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "This task doesn't support region-of-interest." ValueError, "This task doesn't support region-of-interest."

View File

@ -31,7 +31,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
@ -42,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
@ -53,6 +54,12 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class ImageSegmenterResult:
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmenterOptions: class ImageSegmenterOptions:
"""Options for the image segmenter task. """Options for the image segmenter task.
@ -64,19 +71,13 @@ class ImageSegmenterOptions:
objects on single image inputs. 2) The video mode for segmenting objects objects on single image inputs. 2) The video mode for segmenting objects
on the decoded frames of a video. 3) The live stream mode for segmenting on the decoded frames of a video. 3) The live stream mode for segmenting
objects on a live stream of input data, such as from camera. objects on a live stream of input data, such as from camera.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum): class Activation(enum.Enum):
NONE = 0 NONE = 0
SIGMOID = 1 SIGMOID = 1
@ -84,7 +85,8 @@ class ImageSegmenterOptions:
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
output_category_mask: bool = False
activation: Optional[Activation] = Activation.NONE activation: Optional[Activation] = Activation.NONE
result_callback: Optional[ result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None] Callable[[ImageSegmenterResult, image_module.Image, int], None]
@ -98,7 +100,7 @@ class ImageSegmenterOptions:
False if self.running_mode == _RunningMode.IMAGE else True False if self.running_mode == _RunningMode.IMAGE else True
) )
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value activation=self.activation.value
) )
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet.Packet]): def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME] segmentation_result = ImageSegmenterResult()
)
if options.output_confidence_masks:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if options.output_category_mask:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
segmentation_result, segmentation_result,
image, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
) )
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_for_video( def segment_for_video(
@ -285,9 +317,19 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_async( def segment_async(

View File

@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_ROI_STREAM_NAME = 'roi_in' _ROI_STREAM_NAME = 'roi_in'
@ -55,32 +57,32 @@ _TASK_GRAPH_NAME = (
) )
@dataclasses.dataclass
class InteractiveSegmenterResult:
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class InteractiveSegmenterOptions: class InteractiveSegmenterOptions:
"""Options for the interactive segmenter task. """Options for the interactive segmenter task.
Attributes: Attributes:
base_options: Base options for the interactive segmenter task. base_options: Base options for the interactive segmenter task.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
base_options: _BaseOptions base_options: _BaseOptions
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
output_category_mask: bool = False
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an InteractiveSegmenterOptions protobuf object.""" """Generates an InteractiveSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False base_options_proto.use_stream_mode = False
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto()
output_type=self.output_type.value
)
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto, segmenter_options=segmenter_options_proto,
@ -192,6 +194,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If other types of error occurred. RuntimeError: If other types of error occurred.
""" """
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
@ -199,10 +215,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -216,7 +229,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
image: image_module.Image, image: image_module.Image,
roi: RegionOfInterest, roi: RegionOfInterest,
image_processing_options: Optional[_ImageProcessingOptions] = None, image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> List[image_module.Image]: ) -> InteractiveSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image. """Performs the actual segmentation task on the provided MediaPipe Image.
The image can be of any size with format RGB. The image can be of any size with format RGB.
@ -248,7 +261,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = InteractiveSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result

View File

@ -77,6 +77,7 @@ mediapipe_files(srcs = [
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
"pose.jpg", "pose.jpg",
"pose_detection.tflite", "pose_detection.tflite",
"ptm_512_hdt_ptm_woid.tflite",
"pose_landmark_lite.tflite", "pose_landmark_lite.tflite",
"pose_landmarker.task", "pose_landmarker.task",
"right_hands.jpg", "right_hands.jpg",
@ -187,6 +188,7 @@ filegroup(
"mobilenet_v3_small_100_224_embedder.tflite", "mobilenet_v3_small_100_224_embedder.tflite",
"palm_detection_full.tflite", "palm_detection_full.tflite",
"pose_detection.tflite", "pose_detection.tflite",
"ptm_512_hdt_ptm_woid.tflite",
"pose_landmark_lite.tflite", "pose_landmark_lite.tflite",
"pose_landmarker.task", "pose_landmarker.task",
"selfie_segm_128_128_3.tflite", "selfie_segm_128_128_3.tflite",