Populate labels using model metadata for the ImageSegmenter Python API

This commit is contained in:
kinaryml 2023-04-18 02:49:13 -07:00
parent 63cd09951d
commit 723cb2a919
3 changed files with 97 additions and 0 deletions

View File

@ -45,6 +45,29 @@ _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98 _MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
_EXPECTED_LABELS = [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"dining table",
"dog",
"horse",
"motorbike",
"person",
"potted plant",
"sheep",
"sofa",
"train",
"tv"
]
def _similar_to_uint8_mask(actual_mask, expected_mask): def _similar_to_uint8_mask(actual_mask, expected_mask):
@ -214,6 +237,16 @@ class ImageSegmenterTest(parameterized.TestCase):
f'Number of pixels in the candidate mask differing from that of the ' f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_get_labels_succeeds(self):
expected_labels = _EXPECTED_LABELS
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
actual_labels = segmenter.get_labels()
self.assertListEqual(actual_labels, expected_labels)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),

View File

@ -71,6 +71,7 @@ py_library(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",

View File

@ -21,6 +21,7 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.image_segmenter.calculators import tensors_to_segmentation_calculator_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
@ -38,6 +39,9 @@ _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterGraphOptionsProto = ( _ImageSegmenterGraphOptionsProto = (
image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
) )
TensorsToSegmentationCalculatorOptionsProto = (
tensors_to_segmentation_calculator_pb2.TensorsToSegmentationCalculatorOptions
)
_RunningMode = vision_task_running_mode.VisionTaskRunningMode _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -49,6 +53,7 @@ _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in' _NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT' _NORM_RECT_TAG = 'NORM_RECT'
_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = 'mediapipe.tasks.TensorsToSegmentationCalculator'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@ -130,6 +135,40 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
An example of such model can be found at: An example of such model can be found at:
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
""" """
def __init__(self, graph_config, running_mode, packet_callback):
super(ImageSegmenter, self).__init__(
graph_config, running_mode, packet_callback
)
self._populate_labels()
def _populate_labels(self):
"""
Populate the labelmap in TensorsToSegmentationCalculator to labels field.
Returns:
Exception if there is an error during finding TensorsToSegmentationCalculator.
:return:
"""
self.labels = []
graph_config = self._runner.get_graph_config()
found_tensors_to_segmentation = False
for node in graph_config.node:
if _TENSORS_TO_SEGMENTATION_CALCULATOR_NAME in node.name:
if found_tensors_to_segmentation:
raise Exception(
f"The graph has more than one "
f"{_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}."
)
found_tensors_to_segmentation = True
options = node.options.Extensions[
TensorsToSegmentationCalculatorOptionsProto.ext
]
if options.label_items:
for i in range(len(options.label_items)):
if i not in options.label_items:
raise Exception(f"The labelmap has no expected key: {i}.")
self.labels.append(options.label_items[i].name)
@classmethod @classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
@ -209,6 +248,30 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
packets_callback if options.result_callback else None, packets_callback if options.result_callback else None,
) )
def get_labels(self):
""" Get the category label list of the ImageSegmenter can recognize.
For CATEGORY_MASK type, the index in the category mask corresponds to the
category in the label list.
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
category in the label list.
If there is no label map provided in the model file, empty label list is
returned.
Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is
per-category segmented image mask.
If the output_type is CONFIDENCE_MASK, the returned vector of images
contains only one confidence image mask. A segmentation result object that
contains a list of segmentation masks as images.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
return self.labels
def segment( def segment(
self, self,
image: image_module.Image, image: image_module.Image,