Added Image Segmenter implementation and tests

This commit is contained in:
kinaryml 2022-09-25 09:16:13 -07:00
parent e028b24c42
commit d25626ff63
4 changed files with 92 additions and 6 deletions

View File

@ -35,7 +35,7 @@ pybind_extension(
}),
module_name = "_framework_bindings",
deps = [
":builtin_calculators",
#":builtin_calculators",
":builtin_task_graphs",
"//mediapipe/python/pybind:calculator_graph",
"//mediapipe/python/pybind:image",
@ -85,6 +85,7 @@ cc_library(
cc_library(
name = "builtin_task_graphs",
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
],

View File

@ -47,6 +47,7 @@ py_test(
],
deps = [
# build rule placeholder: numpy dep,
# build rule placeholder: cv2 dep,
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_util",
"//mediapipe/tasks/python/components:segmenter_options",

View File

@ -14,11 +14,14 @@
"""Tests for image segmenter."""
import enum
import numpy as np
import cv2
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
from mediapipe.tasks.python.components import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_util
@ -27,6 +30,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat
_OutputType = segmenter_options.OutputType
_Activation = segmenter_options.Activation
_SegmenterOptions = segmenter_options.SegmenterOptions
@ -41,6 +45,13 @@ _MASK_MAGNIFICATION_FACTOR = 10
_MATCH_PIXELS_THRESHOLD = 0.01
def _iou(ground_truth, prediction):
  """Computes the intersection-over-union (IoU) of two masks.

  Both inputs are interpreted element-wise by truthiness (a non-zero element
  counts as belonging to the mask) through `np.logical_and` and
  `np.logical_or`.

  Args:
    ground_truth: Array-like reference mask.
    prediction: Array-like predicted mask, broadcastable against
      `ground_truth`.

  Returns:
    The number of elements set in both masks divided by the number of
    elements set in at least one of them. Returns 1.0 when no element is set
    in either mask: two empty masks agree exactly, whereas the plain ratio
    would evaluate 0 / 0 and yield NaN with a RuntimeWarning.
  """
  intersection = np.logical_and(ground_truth, prediction)
  union = np.logical_or(ground_truth, prediction)
  union_count = np.count_nonzero(union)
  if union_count == 0:
    # Nothing is set in either mask, so there is no disagreement to measure.
    return 1.0
  return np.count_nonzero(intersection) / union_count
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
@ -52,6 +63,7 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp()
self.test_image = test_util.read_test_image(
test_util.get_test_data_path(_IMAGE_FILE))
self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE)
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self):
@ -85,9 +97,9 @@ class ImageSegmenterTest(parameterized.TestCase):
self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4),
(ModelFileType.FILE_CONTENT, 4))
def succeeds_with_category_mask(self, model_file_type, max_results):
(ModelFileType.FILE_NAME,),
(ModelFileType.FILE_CONTENT,))
def test_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path)
@ -106,6 +118,78 @@ class ImageSegmenterTest(parameterized.TestCase):
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertEqual(len(category_masks), 1)
result_pixels = category_masks[0].numpy_view().flatten()
# Check if data type of `category_masks` is correct.
self.assertEqual(result_pixels.dtype, np.uint8)
# Loads ground truth segmentation file.
image_data = cv2.imread(self.test_seg_path, cv2.IMREAD_GRAYSCALE)
gt_segmentation = _Image(_ImageFormat.GRAY8, image_data)
gt_segmentation_array = gt_segmentation.numpy_view()
gt_segmentation_shape = gt_segmentation_array.shape
num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
ground_truth_pixels = gt_segmentation_array.flatten()
self.assertEqual(
len(result_pixels), len(ground_truth_pixels),
'Segmentation mask size does not match the ground truth mask size.')
inconsistent_pixels = 0
for index in range(num_pixels):
inconsistent_pixels += (
result_pixels[index] * _MASK_MAGNIFICATION_FACTOR !=
ground_truth_pixels[index])
self.assertLessEqual(
inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
def test_succeeds_with_confidence_mask(self):
# Creates segmenter.
base_options = _BaseOptions(file_name=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode.
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options)
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode.
segmenter_options = _SegmenterOptions(
output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertEqual(
len(confidence_masks), 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if the data type of `confidence_masks` is correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.

View File

@ -128,7 +128,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
segmentation_result = packet_getter.get_proto_list(
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
@ -166,6 +166,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
"""
output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
segmentation_result = packet_getter.get_proto_list(
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result