Revised API implementation and added more tests for segment_for_video and segment_async

This commit is contained in:
kinaryml 2022-10-18 04:24:12 -07:00
parent 36ac0689d7
commit f84e0bc1c6
3 changed files with 277 additions and 40 deletions

View File

@ -47,7 +47,7 @@ py_test(
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/components/proto:segmenter_options", "//mediapipe/tasks/python/components/proto:segmenter_options",
"//mediapipe/tasks/python/vision:image_segmenter", "//mediapipe/tasks/python/vision:image_segmenter",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",

View File

@ -16,6 +16,8 @@
import enum import enum
import numpy as np import numpy as np
import cv2 import cv2
from typing import List
from unittest import mock
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
@ -24,7 +26,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module from mediapipe.python._framework_bindings import image_frame as image_frame_module
from mediapipe.tasks.python.components.proto import segmenter_options from mediapipe.tasks.python.components.proto import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_util from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
@ -42,7 +44,22 @@ _MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg' _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MATCH_PIXELS_THRESHOLD = 0.01 _MASK_SIMILARITY_THRESHOLD = 0.98
def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten()
consistent_pixels = 0
num_pixels = len(expected_mask_pixels)
for index in range(num_pixels):
consistent_pixels += (
actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR ==
expected_mask_pixels[index])
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
@ -54,10 +71,14 @@ class ImageSegmenterTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = test_util.read_test_image( # Load the test input image.
test_util.get_test_data_path(_IMAGE_FILE)) self.test_image = _Image.create_from_file(
self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE) test_utils.get_test_data_path(_IMAGE_FILE))
self.model_path = test_util.get_test_data_path(_MODEL_FILE) # Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(_SEGMENTATION_FILE), cv2.IMREAD_GRAYSCALE)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -76,7 +97,7 @@ class ImageSegmenterTest(parameterized.TestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
r"ExternalFile must specify at least one of 'file_content', " r"ExternalFile must specify at least one of 'file_content', "
r"'file_name' or 'file_descriptor_meta'."): r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='') base_options = _BaseOptions(model_asset_path='')
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)
@ -112,34 +133,16 @@ class ImageSegmenterTest(parameterized.TestCase):
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) category_masks = segmenter.segment(self.test_image)
self.assertEqual(len(category_masks), 1) self.assertEqual(len(category_masks), 1)
result_pixels = category_masks[0].numpy_view().flatten() category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_masks` is correct. # Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8) self.assertEqual(result_pixels.dtype, np.uint8)
# Loads ground truth segmentation file. self.assertTrue(
image_data = cv2.imread(self.test_seg_path, cv2.IMREAD_GRAYSCALE) _similar_to_uint8_mask(category_masks[0], self.test_seg_image),
gt_segmentation = _Image(_ImageFormat.GRAY8, image_data)
gt_segmentation_array = gt_segmentation.numpy_view()
gt_segmentation_shape = gt_segmentation_array.shape
num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
ground_truth_pixels = gt_segmentation_array.flatten()
self.assertEqual(
len(result_pixels), len(ground_truth_pixels),
'Segmentation mask size does not match the ground truth mask size.')
inconsistent_pixels = 0
for index in range(num_pixels):
inconsistent_pixels += (
result_pixels[index] * _MASK_MAGNIFICATION_FACTOR !=
ground_truth_pixels[index])
self.assertLessEqual(
inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
f'Number of pixels in the candidate mask differing from that of the ' f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.') f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
# Closes the segmenter explicitly when the segmenter is not used in # Closes the segmenter explicitly when the segmenter is not used in
# a context. # a context.
@ -188,6 +191,174 @@ class ImageSegmenterTest(parameterized.TestCase):
# a context. # a context.
segmenter.close() segmenter.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME,),
(ModelFileType.FILE_CONTENT,))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(base_options=base_options,
segmenter_options=segmenter_options)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertEqual(len(category_masks), 1)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_missing_result_callback(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
def test_calling_segment_for_video_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
segmenter.segment_for_video(self.test_image, 0)
def test_calling_segment_async_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
segmenter.segment_async(self.test_image, 0)
def test_calling_segment_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
segmenter.segment(self.test_image)
def test_calling_segment_async_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
segmenter.segment_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
unused_result = segmenter.segment_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self):
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
segmenter_options=segmenter_options,
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp)
self.assertEqual(len(category_masks), 1)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
segmenter.segment(self.test_image)
def test_calling_segment_for_video_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_async_calls_with_illegal_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
segmenter.segment_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self):
observed_timestamp_ms = -1
def check_result(result: List[image_module.Image],
output_image: _Image,
timestamp_ms: int):
# Get the output category mask.
category_mask = result[0]
self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width)
self.assertEqual(output_image.height, self.test_seg_image.height)
self.assertTrue(
_similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
segmenter_options=segmenter_options,
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -42,6 +42,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass @dataclasses.dataclass
@ -52,9 +53,9 @@ class ImageSegmenterOptions:
base_options: Base options for the image segmenter task. base_options: Base options for the image segmenter task.
running_mode: The running mode of the task. Default to the image mode. running_mode: The running mode of the task. Default to the image mode.
Image segmenter task has three running modes: Image segmenter task has three running modes:
1) The image mode for detecting objects on single image inputs. 1) The image mode for segmenting objects on single image inputs.
2) The video mode for detecting objects on the decoded frames of a video. 2) The video mode for segmenting objects on the decoded frames of a video.
3) The live stream mode for detecting objects on a live stream of input 3) The live stream mode for segmenting objects on a live stream of input
data, such as from camera. data, such as from camera.
segmenter_options: Options for the image segmenter task. segmenter_options: Options for the image segmenter task.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
@ -86,7 +87,8 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
@classmethod @classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
"""Creates an `ImageSegmenter` object from a TensorFlow Lite model and the default `ImageSegmenterOptions`. """Creates an `ImageSegmenter` object from a TensorFlow Lite model and the
default `ImageSegmenterOptions`.
Note that the created `ImageSegmenter` instance is in image mode, for Note that the created `ImageSegmenter` instance is in image mode, for
performing image segmentation on single image inputs. performing image segmentation on single image inputs.
@ -131,8 +133,9 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
segmentation_result = packet_getter.get_image_list( segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]) output_packets[_SEGMENTATION_OUT_STREAM_NAME])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
options.result_callback(segmentation_result, image, timestamp) options.result_callback(segmentation_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -148,7 +151,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
_RunningMode.LIVE_STREAM), options.running_mode, _RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None) packets_callback if options.result_callback else None)
# TODO: Create an Image class for MediaPipe Tasks.
def segment(self, def segment(self,
image: image_module.Image) -> List[image_module.Image]: image: image_module.Image) -> List[image_module.Image]:
"""Performs the actual segmentation task on the provided MediaPipe Image. """Performs the actual segmentation task on the provided MediaPipe Image.
@ -162,10 +164,74 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If object detection failed to run. RuntimeError: If image segmentation failed to run.
""" """
output_packets = self._process_image_data( output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
segmentation_result = packet_getter.get_image_list( segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME]) output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result return segmentation_result
def segment_for_video(self, image: image_module.Image,
timestamp_ms: int) -> List[image_module.Image]:
"""Performs segmentation on the provided video frames.
Only use this method when the ImageSegmenter is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
Returns:
A segmentation result object that contains a list of segmentation masks
as images.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result
def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform
image segmentation.
Only use this method when the ImageSegmenter is created with the live stream
running mode. The input timestamps should be monotonically increasing for
adjacent calls of this method. This method will return immediately after the
input image is accepted. The results will be available via the
`result_callback` provided in the `ImageSegmenterOptions`. The
`segment_async` method is designed to process live stream data such as
camera input. To lower the overall latency, image segmenter may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` prvoides:
- A segmentation result object that contains a list of segmentation masks
as images.
- The input image that the image segmenter runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
Raises:
ValueError: If the current input timestamp is smaller than what the image
segmenter has already processed.
"""
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})