Added Image Segmenter implementation and tests
This commit is contained in:
parent
e028b24c42
commit
d25626ff63
|
@ -35,7 +35,7 @@ pybind_extension(
|
|||
}),
|
||||
module_name = "_framework_bindings",
|
||||
deps = [
|
||||
":builtin_calculators",
|
||||
#":builtin_calculators",
|
||||
":builtin_task_graphs",
|
||||
"//mediapipe/python/pybind:calculator_graph",
|
||||
"//mediapipe/python/pybind:image",
|
||||
|
@ -85,6 +85,7 @@ cc_library(
|
|||
cc_library(
|
||||
name = "builtin_task_graphs",
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
],
|
||||
|
|
|
@ -47,6 +47,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
# build rule placeholder: numpy dep,
|
||||
# build rule placeholder: cv2 dep,
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_util",
|
||||
"//mediapipe/tasks/python/components:segmenter_options",
|
||||
|
|
|
@ -14,11 +14,14 @@
|
|||
"""Tests for image segmenter."""
|
||||
|
||||
import enum
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||
from mediapipe.tasks.python.components import segmenter_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_util
|
||||
|
@ -27,6 +30,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
|||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame_module.ImageFormat
|
||||
_OutputType = segmenter_options.OutputType
|
||||
_Activation = segmenter_options.Activation
|
||||
_SegmenterOptions = segmenter_options.SegmenterOptions
|
||||
|
@ -41,6 +45,13 @@ _MASK_MAGNIFICATION_FACTOR = 10
|
|||
_MATCH_PIXELS_THRESHOLD = 0.01
|
||||
|
||||
|
||||
def _iou(ground_truth, prediction):
|
||||
intersection = np.logical_and(ground_truth, prediction)
|
||||
union = np.logical_or(ground_truth, prediction)
|
||||
iou = np.sum(intersection) / np.sum(union)
|
||||
return iou
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
  """Selects how a test case supplies the model to the segmenter.

  Used as a parameter of the parameterized tests in this file.
  """

  # NOTE(review): presumably the model is passed as in-memory file content;
  # the branch that handles this value is not visible here -- confirm.
  FILE_CONTENT = 1
  # The model is referenced by its path on disk (`file_name` base option).
  FILE_NAME = 2
|
||||
|
@ -52,6 +63,7 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
super().setUp()
|
||||
self.test_image = test_util.read_test_image(
|
||||
test_util.get_test_data_path(_IMAGE_FILE))
|
||||
self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE)
|
||||
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
|
@ -85,9 +97,9 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, 4),
|
||||
(ModelFileType.FILE_CONTENT, 4))
|
||||
def succeeds_with_category_mask(self, model_file_type, max_results):
|
||||
(ModelFileType.FILE_NAME,),
|
||||
(ModelFileType.FILE_CONTENT,))
|
||||
def test_succeeds_with_category_mask(self, model_file_type):
|
||||
# Creates segmenter.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(file_name=self.model_path)
|
||||
|
@ -106,6 +118,78 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
|
||||
# Performs image segmentation on the input.
|
||||
category_masks = segmenter.segment(self.test_image)
|
||||
self.assertEqual(len(category_masks), 1)
|
||||
result_pixels = category_masks[0].numpy_view().flatten()
|
||||
|
||||
# Check if data type of `category_masks` is correct.
|
||||
self.assertEqual(result_pixels.dtype, np.uint8)
|
||||
|
||||
# Loads ground truth segmentation file.
|
||||
image_data = cv2.imread(self.test_seg_path, cv2.IMREAD_GRAYSCALE)
|
||||
gt_segmentation = _Image(_ImageFormat.GRAY8, image_data)
|
||||
gt_segmentation_array = gt_segmentation.numpy_view()
|
||||
gt_segmentation_shape = gt_segmentation_array.shape
|
||||
num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
|
||||
ground_truth_pixels = gt_segmentation_array.flatten()
|
||||
|
||||
self.assertEqual(
|
||||
len(result_pixels), len(ground_truth_pixels),
|
||||
'Segmentation mask size does not match the ground truth mask size.')
|
||||
|
||||
inconsistent_pixels = 0
|
||||
|
||||
for index in range(num_pixels):
|
||||
inconsistent_pixels += (
|
||||
result_pixels[index] * _MASK_MAGNIFICATION_FACTOR !=
|
||||
ground_truth_pixels[index])
|
||||
|
||||
self.assertLessEqual(
|
||||
inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
|
||||
f'Number of pixels in the candidate mask differing from that of the '
|
||||
f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.')
|
||||
|
||||
# Closes the segmenter explicitly when the segmenter is not used in
|
||||
# a context.
|
||||
segmenter.close()
|
||||
|
||||
def test_succeeds_with_confidence_mask(self):
|
||||
# Creates segmenter.
|
||||
base_options = _BaseOptions(file_name=self.model_path)
|
||||
|
||||
# Run segmentation on the model in CATEGORY_MASK mode.
|
||||
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
|
||||
options = _ImageSegmenterOptions(base_options=base_options,
|
||||
segmenter_options=segmenter_options)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
category_masks = segmenter.segment(self.test_image)
|
||||
category_mask = category_masks[0].numpy_view()
|
||||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
segmenter_options = _SegmenterOptions(
|
||||
output_type=_OutputType.CONFIDENCE_MASK,
|
||||
activation=_Activation.SOFTMAX)
|
||||
options = _ImageSegmenterOptions(base_options=base_options,
|
||||
segmenter_options=segmenter_options)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
confidence_masks = segmenter.segment(self.test_image)
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertEqual(
|
||||
len(confidence_masks), 21,
|
||||
'Number of confidence masks must match with number of categories.')
|
||||
|
||||
# Gather the confidence masks in a single array `confidence_mask_array`.
|
||||
confidence_mask_array = np.array(
|
||||
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
|
||||
|
||||
# Check if data type of `confidence_masks` is correct.
|
||||
self.assertEqual(confidence_mask_array.dtype, np.float32)
|
||||
|
||||
# Compute the category mask from the created confidence mask.
|
||||
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
|
||||
self.assertListEqual(
|
||||
calculated_category_mask.tolist(), category_mask.tolist(),
|
||||
'Confidence mask does not match with the category mask.')
|
||||
|
||||
# Closes the segmenter explicitly when the segmenter is not used in
|
||||
# a context.
|
||||
|
|
|
@ -128,7 +128,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||
return
|
||||
segmentation_result = packet_getter.get_proto_list(
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
|
||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
|
@ -166,6 +166,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
"""
|
||||
output_packets = self._process_image_data(
|
||||
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
|
||||
segmentation_result = packet_getter.get_proto_list(
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
|
||||
return segmentation_result
|
||||
|
|
Loading…
Reference in New Issue
Block a user