Merge pull request #4302 from kinaryml:segmenter-python-add-labels

PiperOrigin-RevId: 525571089
This commit is contained in:
Copybara-Service 2023-04-19 15:46:20 -07:00
commit 44aa607e06
3 changed files with 96 additions and 2 deletions

View File

@ -45,6 +45,29 @@ _CAT_MASK = 'cat_mask.jpg'
_MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
_EXPECTED_LABELS = [
'background',
'aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'dining table',
'dog',
'horse',
'motorbike',
'person',
'potted plant',
'sheep',
'sofa',
'train',
'tv',
]
def _calculate_soft_iou(m1, m2):
@ -224,6 +247,20 @@ class ImageSegmenterTest(parameterized.TestCase):
)
)
@parameterized.parameters((True, False), (False, True))
def test_labels_succeeds(self, output_category_mask, output_confidence_masks):
expected_labels = _EXPECTED_LABELS
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(
base_options=base_options,
output_category_mask=output_category_mask,
output_confidence_masks=output_confidence_masks,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
actual_labels = segmenter.labels
self.assertListEqual(actual_labels, expected_labels)
def test_missing_result_callback(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),

View File

@ -71,6 +71,7 @@ py_library(
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/python/components/containers:rect",

View File

@ -20,6 +20,7 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.image_segmenter.calculators import tensors_to_segmentation_calculator_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2
from mediapipe.tasks.python.components.containers import rect
@ -36,6 +37,9 @@ _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterGraphOptionsProto = (
image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
)
TensorsToSegmentationCalculatorOptionsProto = (
tensors_to_segmentation_calculator_pb2.TensorsToSegmentationCalculatorOptions
)
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -49,6 +53,9 @@ _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = (
'mediapipe.tasks.TensorsToSegmentationCalculator'
)
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@ -124,8 +131,8 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Output tensors:
(kTfLiteUInt8/kTfLiteFloat32)
- list of segmented masks.
- if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
- if `output_type` is CONFIDENCE_MASK, float32 Image list of size
- if `output_category_mask` is True, uint8 Image, Image vector of size 1.
- if `output_confidence_masks` is True, float32 Image list of size
`channels`.
- batch is always 1
@ -133,6 +140,41 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
"""
def __init__(self, graph_config, running_mode, packet_callback) -> None:
"""Initializes the `ImageSegmenter` object."""
super(ImageSegmenter, self).__init__(
graph_config, running_mode, packet_callback
)
self._populate_labels()
def _populate_labels(self) -> None:
"""Populate the labelmap in TensorsToSegmentationCalculator to labels field.
Raises:
ValueError if there is an error during finding
TensorsToSegmentationCalculator.
"""
self._labels = []
graph_config = self._runner.get_graph_config()
found_tensors_to_segmentation = False
for node in graph_config.node:
if _TENSORS_TO_SEGMENTATION_CALCULATOR_NAME in node.name:
if found_tensors_to_segmentation:
raise ValueError(
'The graph has more than one '
f'{_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.'
)
found_tensors_to_segmentation = True
options = node.options.Extensions[
TensorsToSegmentationCalculatorOptionsProto.ext
]
if options.label_items:
for i in range(len(options.label_items)):
if i not in options.label_items:
raise ValueError(f'The labelmap has no expected key: {i}.')
self._labels.append(options.label_items[i].name)
@classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
"""Creates an `ImageSegmenter` object from a TensorFlow Lite model and the default `ImageSegmenterOptions`.
@ -375,3 +417,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
@property
def labels(self) -> List[str]:
"""Get the category label list the ImageSegmenter can recognize.
For CATEGORY_MASK type, the index in the category mask corresponds to the
category in the label list.
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
category in the label list.
If there is no label map provided in the model file, empty label list is
returned.
"""
return self._labels