Added Image Segmenter implementation and tests
This commit is contained in:
parent e028b24c42
commit d25626ff63
@@ -35,7 +35,7 @@ pybind_extension(
     }),
     module_name = "_framework_bindings",
     deps = [
-        ":builtin_calculators",
+        #":builtin_calculators",
         ":builtin_task_graphs",
         "//mediapipe/python/pybind:calculator_graph",
         "//mediapipe/python/pybind:image",
@@ -85,6 +85,7 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
     ],
@@ -47,6 +47,7 @@ py_test(
     ],
     deps = [
         # build rule placeholder: numpy dep,
+        # build rule placeholder: cv2 dep,
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/test:test_util",
         "//mediapipe/tasks/python/components:segmenter_options",
@@ -14,11 +14,14 @@
 """Tests for image segmenter."""
 
 import enum
+import numpy as np
+import cv2
 
 from absl.testing import absltest
 from absl.testing import parameterized
 
 from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.python._framework_bindings import image_frame as image_frame_module
 from mediapipe.tasks.python.components import segmenter_options
 from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.test import test_util
@@ -27,6 +30,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
 
 _BaseOptions = base_options_module.BaseOptions
 _Image = image_module.Image
+_ImageFormat = image_frame_module.ImageFormat
 _OutputType = segmenter_options.OutputType
 _Activation = segmenter_options.Activation
 _SegmenterOptions = segmenter_options.SegmenterOptions
@@ -41,6 +45,13 @@ _MASK_MAGNIFICATION_FACTOR = 10
 _MATCH_PIXELS_THRESHOLD = 0.01
 
 
+def _iou(ground_truth, prediction):
+  intersection = np.logical_and(ground_truth, prediction)
+  union = np.logical_or(ground_truth, prediction)
+  iou = np.sum(intersection) / np.sum(union)
+  return iou
+
+
 class ModelFileType(enum.Enum):
   FILE_CONTENT = 1
   FILE_NAME = 2
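Note: the `_iou` helper added above computes the standard intersection-over-union score between two boolean masks; it is not exercised by the hunks shown here. A minimal standalone sketch of its behavior on toy NumPy masks (illustration only, not code from this commit):

import numpy as np

def _iou(ground_truth, prediction):
  # Element-wise overlap and union of the two boolean masks.
  intersection = np.logical_and(ground_truth, prediction)
  union = np.logical_or(ground_truth, prediction)
  return np.sum(intersection) / np.sum(union)

# Two 2x2 masks that agree on one of three foreground pixels.
gt = np.array([[1, 0], [1, 0]], dtype=bool)
pred = np.array([[1, 1], [0, 0]], dtype=bool)
print(_iou(gt, pred))  # 1 intersecting pixel / 3 union pixels = 0.333...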
@@ -52,6 +63,7 @@ class ImageSegmenterTest(parameterized.TestCase):
     super().setUp()
     self.test_image = test_util.read_test_image(
         test_util.get_test_data_path(_IMAGE_FILE))
+    self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE)
     self.model_path = test_util.get_test_data_path(_MODEL_FILE)
 
   def test_create_from_file_succeeds_with_valid_model_path(self):
@@ -85,9 +97,9 @@ class ImageSegmenterTest(parameterized.TestCase):
     self.assertIsInstance(segmenter, _ImageSegmenter)
 
   @parameterized.parameters(
-      (ModelFileType.FILE_NAME, 4),
-      (ModelFileType.FILE_CONTENT, 4))
-  def succeeds_with_category_mask(self, model_file_type, max_results):
+      (ModelFileType.FILE_NAME,),
+      (ModelFileType.FILE_CONTENT,))
+  def test_succeeds_with_category_mask(self, model_file_type):
     # Creates segmenter.
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(file_name=self.model_path)
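Note: the change above drops the unused `max_results` argument, so each parameter tuple now carries only the model file type, and the method is renamed so the test runner picks it up. As a standalone sketch of how `absl.testing.parameterized` expands such single-element tuples (illustration only, not code from this commit):

from absl.testing import absltest, parameterized

class DemoTest(parameterized.TestCase):

  @parameterized.parameters(('file_name',), ('file_content',))
  def test_demo(self, model_file_type):
    # One test case is generated per tuple: test_demo0 and test_demo1.
    self.assertIn(model_file_type, ('file_name', 'file_content'))

if __name__ == '__main__':
  absltest.main()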
@@ -106,6 +118,78 @@ class ImageSegmenterTest(parameterized.TestCase):
 
     # Performs image segmentation on the input.
     category_masks = segmenter.segment(self.test_image)
+    self.assertEqual(len(category_masks), 1)
+    result_pixels = category_masks[0].numpy_view().flatten()
+
+    # Check if data type of `category_masks` is correct.
+    self.assertEqual(result_pixels.dtype, np.uint8)
+
+    # Loads ground truth segmentation file.
+    image_data = cv2.imread(self.test_seg_path, cv2.IMREAD_GRAYSCALE)
+    gt_segmentation = _Image(_ImageFormat.GRAY8, image_data)
+    gt_segmentation_array = gt_segmentation.numpy_view()
+    gt_segmentation_shape = gt_segmentation_array.shape
+    num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
+    ground_truth_pixels = gt_segmentation_array.flatten()
+
+    self.assertEqual(
+        len(result_pixels), len(ground_truth_pixels),
+        'Segmentation mask size does not match the ground truth mask size.')
+
+    inconsistent_pixels = 0
+
+    for index in range(num_pixels):
+      inconsistent_pixels += (
+          result_pixels[index] * _MASK_MAGNIFICATION_FACTOR !=
+          ground_truth_pixels[index])
+
+    self.assertLessEqual(
+        inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
+        f'Number of pixels in the candidate mask differing from that of the '
+        f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.')
+
+    # Closes the segmenter explicitly when the segmenter is not used in
+    # a context.
+    segmenter.close()
+
+  def test_succeeds_with_confidence_mask(self):
+    # Creates segmenter.
+    base_options = _BaseOptions(file_name=self.model_path)
+
+    # Run segmentation on the model in CATEGORY_MASK mode.
+    segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
+    options = _ImageSegmenterOptions(base_options=base_options,
+                                     segmenter_options=segmenter_options)
+    segmenter = _ImageSegmenter.create_from_options(options)
+    category_masks = segmenter.segment(self.test_image)
+    category_mask = category_masks[0].numpy_view()
+
+    # Run segmentation on the model in CONFIDENCE_MASK mode.
+    segmenter_options = _SegmenterOptions(
+        output_type=_OutputType.CONFIDENCE_MASK,
+        activation=_Activation.SOFTMAX)
+    options = _ImageSegmenterOptions(base_options=base_options,
+                                     segmenter_options=segmenter_options)
+    segmenter = _ImageSegmenter.create_from_options(options)
+    confidence_masks = segmenter.segment(self.test_image)
+
+    # Check if confidence mask shape is correct.
+    self.assertEqual(
+        len(confidence_masks), 21,
+        'Number of confidence masks must match with number of categories.')
+
+    # Gather the confidence masks in a single array `confidence_mask_array`.
+    confidence_mask_array = np.array(
+        [confidence_mask.numpy_view() for confidence_mask in confidence_masks])
+
+    # Check if data type of `confidence_masks` are correct.
+    self.assertEqual(confidence_mask_array.dtype, np.float32)
+
+    # Compute the category mask from the created confidence mask.
+    calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
+    self.assertListEqual(
+        calculated_category_mask.tolist(), category_mask.tolist(),
+        'Confidence mask does not match with the category mask.')
+
     # Closes the segmenter explicitly when the segmenter is not used in
     # a context.
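Note: the category-mask check added above scales each predicted label by `_MASK_MAGNIFICATION_FACTOR`, compares it with the ground-truth pixel, and allows at most a `_MATCH_PIXELS_THRESHOLD` fraction of mismatches. The per-pixel Python loop can be expressed equivalently as one vectorized NumPy expression; a standalone sketch with toy data (same logic as the test, not part of the commit):

import numpy as np

_MASK_MAGNIFICATION_FACTOR = 10
_MATCH_PIXELS_THRESHOLD = 0.01

# Toy stand-ins for the flattened candidate and ground-truth masks.
result_pixels = np.array([0, 1, 2, 2], dtype=np.uint8)
ground_truth_pixels = np.array([0, 10, 20, 10], dtype=np.uint8)

# Fraction of pixels whose scaled label disagrees with the ground truth.
# Cast to int32 first so the multiplication cannot overflow uint8.
mismatch_ratio = np.mean(
    result_pixels.astype(np.int32) * _MASK_MAGNIFICATION_FACTOR
    != ground_truth_pixels)
print(mismatch_ratio)  # 0.25 on this toy data; the test requires <= 0.01.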
@@ -128,7 +128,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
     def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
         return
-      segmentation_result = packet_getter.get_proto_list(
+      segmentation_result = packet_getter.get_image_list(
          output_packets[_SEGMENTATION_OUT_STREAM_NAME])
       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
       timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
@@ -166,6 +166,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
     """
     output_packets = self._process_image_data(
         {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
-    segmentation_result = packet_getter.get_proto_list(
+    segmentation_result = packet_getter.get_image_list(
         output_packets[_SEGMENTATION_OUT_STREAM_NAME])
     return segmentation_result
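Note: taken together, the hunks above register the image segmenter task graph with the Python framework bindings, add the Python tests, and switch result decoding from proto lists to image lists. A minimal end-to-end usage sketch of the API the tests exercise (the `mediapipe.tasks.python.vision.image_segmenter` import path, file paths, and public class names are assumptions inferred from the aliases in the diff, not shown in these hunks):

import cv2

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
from mediapipe.tasks.python.components import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module
# Assumed import path for the task wrapper; not shown in the visible hunks.
from mediapipe.tasks.python.vision import image_segmenter as image_segmenter_module

# Load an input image as a MediaPipe Image (paths are placeholders).
bgr = cv2.imread('/path/to/input.jpg')
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
mp_image = image_module.Image(image_frame_module.ImageFormat.SRGB, rgb)

# Configure the segmenter to emit a single category mask, as in the tests.
options = image_segmenter_module.ImageSegmenterOptions(
    base_options=base_options_module.BaseOptions(
        file_name='/path/to/model.tflite'),
    segmenter_options=segmenter_options.SegmenterOptions(
        output_type=segmenter_options.OutputType.CATEGORY_MASK))

# `create_from_options` and `segment` mirror the calls made in the tests above;
# using the segmenter as a context manager closes it automatically.
with image_segmenter_module.ImageSegmenter.create_from_options(options) as segmenter:
  masks = segmenter.segment(mp_image)
  category_mask = masks[0].numpy_view()  # one uint8 label per pixel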