Added Image Segmenter implementation and tests

This commit is contained in:
kinaryml 2022-09-25 09:16:13 -07:00
parent e028b24c42
commit d25626ff63
4 changed files with 92 additions and 6 deletions

View File

@ -35,7 +35,7 @@ pybind_extension(
}), }),
module_name = "_framework_bindings", module_name = "_framework_bindings",
deps = [ deps = [
":builtin_calculators", #":builtin_calculators",
":builtin_task_graphs", ":builtin_task_graphs",
"//mediapipe/python/pybind:calculator_graph", "//mediapipe/python/pybind:calculator_graph",
"//mediapipe/python/pybind:image", "//mediapipe/python/pybind:image",
@ -85,6 +85,7 @@ cc_library(
cc_library( cc_library(
name = "builtin_task_graphs", name = "builtin_task_graphs",
deps = [ deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
], ],

View File

@ -47,6 +47,7 @@ py_test(
], ],
deps = [ deps = [
# build rule placeholder: numpy dep, # build rule placeholder: numpy dep,
# build rule placeholder: cv2 dep,
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/test:test_util",
"//mediapipe/tasks/python/components:segmenter_options", "//mediapipe/tasks/python/components:segmenter_options",

View File

@ -14,11 +14,14 @@
"""Tests for image segmenter.""" """Tests for image segmenter."""
import enum import enum
import numpy as np
import cv2
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
from mediapipe.tasks.python.components import segmenter_options from mediapipe.tasks.python.components import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_util from mediapipe.tasks.python.test import test_util
@ -27,6 +30,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat
_OutputType = segmenter_options.OutputType _OutputType = segmenter_options.OutputType
_Activation = segmenter_options.Activation _Activation = segmenter_options.Activation
_SegmenterOptions = segmenter_options.SegmenterOptions _SegmenterOptions = segmenter_options.SegmenterOptions
@ -41,6 +45,13 @@ _MASK_MAGNIFICATION_FACTOR = 10
_MATCH_PIXELS_THRESHOLD = 0.01 _MATCH_PIXELS_THRESHOLD = 0.01
def _iou(ground_truth, prediction):
intersection = np.logical_and(ground_truth, prediction)
union = np.logical_or(ground_truth, prediction)
iou = np.sum(intersection) / np.sum(union)
return iou
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
FILE_CONTENT = 1 FILE_CONTENT = 1
FILE_NAME = 2 FILE_NAME = 2
@ -52,6 +63,7 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp() super().setUp()
self.test_image = test_util.read_test_image( self.test_image = test_util.read_test_image(
test_util.get_test_data_path(_IMAGE_FILE)) test_util.get_test_data_path(_IMAGE_FILE))
self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE)
self.model_path = test_util.get_test_data_path(_MODEL_FILE) self.model_path = test_util.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
@ -85,9 +97,9 @@ class ImageSegmenterTest(parameterized.TestCase):
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, 4), (ModelFileType.FILE_NAME,),
(ModelFileType.FILE_CONTENT, 4)) (ModelFileType.FILE_CONTENT,))
def succeeds_with_category_mask(self, model_file_type, max_results): def test_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter. # Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path) base_options = _BaseOptions(file_name=self.model_path)
@ -106,6 +118,78 @@ class ImageSegmenterTest(parameterized.TestCase):
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) category_masks = segmenter.segment(self.test_image)
self.assertEqual(len(category_masks), 1)
result_pixels = category_masks[0].numpy_view().flatten()
# Check if data type of `category_masks` is correct.
self.assertEqual(result_pixels.dtype, np.uint8)
# Loads ground truth segmentation file.
image_data = cv2.imread(self.test_seg_path, cv2.IMREAD_GRAYSCALE)
gt_segmentation = _Image(_ImageFormat.GRAY8, image_data)
gt_segmentation_array = gt_segmentation.numpy_view()
gt_segmentation_shape = gt_segmentation_array.shape
num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
ground_truth_pixels = gt_segmentation_array.flatten()
self.assertEqual(
len(result_pixels), len(ground_truth_pixels),
'Segmentation mask size does not match the ground truth mask size.')
inconsistent_pixels = 0
for index in range(num_pixels):
inconsistent_pixels += (
result_pixels[index] * _MASK_MAGNIFICATION_FACTOR !=
ground_truth_pixels[index])
self.assertLessEqual(
inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
def test_succeeds_with_confidence_mask(self):
# Creates segmenter.
base_options = _BaseOptions(file_name=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode.
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options)
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode.
segmenter_options = _SegmenterOptions(
output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertEqual(
len(confidence_masks), 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in # Closes the segmenter explicitly when the segmenter is not used in
# a context. # a context.

View File

@ -128,7 +128,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet_module.Packet]): def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
segmentation_result = packet_getter.get_proto_list( segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]) output_packets[_SEGMENTATION_OUT_STREAM_NAME])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
@ -166,6 +166,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
""" """
output_packets = self._process_image_data( output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
segmentation_result = packet_getter.get_proto_list( segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]) output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result return segmentation_result