Merge branch 'google:master' into segmenter-python-add-labels

This commit is contained in:
Kinar R 2023-04-19 09:51:42 +05:30 committed by GitHub
commit d621df8046
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 459 additions and 229 deletions

View File

@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
int actual_ws = image_frame.WidthStep(); int actual_ws = image_frame.WidthStep();
int alignment = 0; int alignment = 0;
std::unique_ptr<ImageFrame> temp; std::unique_ptr<ImageFrame> temp;
const uint8* data = image_frame.PixelData(); const uint8_t* data = image_frame.PixelData();
// Let's see if the pixel data is tightly aligned to one of the alignments // Let's see if the pixel data is tightly aligned to one of the alignments
// supported by OpenGL, preferring 4 if possible since it's the default. // supported by OpenGL, preferring 4 if possible since it's the default.

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"], data = [":testdata"],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":object_detector_import",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset from mediapipe.model_maker.python.vision import object_detector
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.tasks.python.test import test_utils as task_test_utils from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp() super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data') dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder( self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir dataset_folder, cache_dir=cache_dir
) )
# Mock tempfile.gettempdir() to be unique for each test to avoid race # Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop) self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self): def test_object_detector(self):
hparams = hyperparameters.HParams( hparams = object_detector.HParams(
epochs=1, epochs=1,
batch_size=2, batch_size=2,
learning_rate=0.9, learning_rate=0.9,
shuffle=False, shuffle=False,
export_dir=self.create_tempdir(), export_dir=self.create_tempdir(),
) )
options = object_detector_options.ObjectDetectorOptions( options = object_detector.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
) )
# Test `create`` # Test `create``
model = object_detector.ObjectDetector.create( model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training` # Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams( qat_hparams = object_detector.QATHParams(
learning_rate=0.9, learning_rate=0.9,
batch_size=2, batch_size=2,
epochs=1, epochs=1,

View File

@ -24,8 +24,8 @@ namespace mediapipe {
void FrameAnnotationTracker::AddDetectionResult( void FrameAnnotationTracker::AddDetectionResult(
const FrameAnnotation& frame_annotation) { const FrameAnnotation& frame_annotation) {
const int64 time_us = const int64_t time_us =
static_cast<int64>(std::round(frame_annotation.timestamp())); static_cast<int64_t>(std::round(frame_annotation.timestamp()));
for (const auto& object_annotation : frame_annotation.annotations()) { for (const auto& object_annotation : frame_annotation.annotations()) {
detected_objects_[time_us + object_annotation.object_id()] = detected_objects_[time_us + object_annotation.object_id()] =
object_annotation; object_annotation;
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
absl::flat_hash_set<int>* cancel_object_ids) { absl::flat_hash_set<int>* cancel_object_ids) {
CHECK(cancel_object_ids != nullptr); CHECK(cancel_object_ids != nullptr);
FrameAnnotation frame_annotation; FrameAnnotation frame_annotation;
std::vector<int64> keys_to_be_deleted; std::vector<int64_t> keys_to_be_deleted;
for (const auto& detected_obj : detected_objects_) { for (const auto& detected_obj : detected_objects_) {
const int object_id = detected_obj.second.object_id(); const int object_id = detected_obj.second.object_id();
if (cancel_object_ids->contains(object_id)) { if (cancel_object_ids->contains(object_id)) {

View File

@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Tolerance for embedding vector coordinate values. # Tolerance for embedding vector coordinate values.
_EPSILON = 1e-4 _EPSILON = 1e-4
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed(self, l2_normalize, quantize, model_name, model_file_type, def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values): expected_similarity, expected_size, expected_first_values):
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed_in_context(self, l2_normalize, quantize, model_name, def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size, model_file_type, expected_similarity, expected_size,
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
# TODO: The similarity should likely be lower # TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880), (_BERT_MODEL_FILE, 0.980880),
(_USE_MODEL_FILE, 0.780334),
) )
def test_embed_with_different_themes(self, model_file, expected_similarity): def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder. # Creates embedder.

View File

@ -15,7 +15,6 @@
import enum import enum
import os import os
from typing import List
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
@ -42,6 +40,8 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'deeplabv3.tflite' _MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg' _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_CAT_IMAGE = 'cat.jpg'
_CAT_MASK = 'cat_mask.jpg'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98 _MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
@ -70,6 +70,26 @@ _EXPECTED_LABELS = [
] ]
def _calculate_soft_iou(m1, m2):
intersection_sum = np.sum(m1 * m2)
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
if union_sum > 0:
return intersection_sum / union_sum
else:
return 0
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
actual_mask = actual_mask.numpy_view()
expected_mask = expected_mask.numpy_view() / 255.0
return (
actual_mask.shape == expected_mask.shape
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
)
def _similar_to_uint8_mask(actual_mask, expected_mask): def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten() actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten() expected_mask_pixels = expected_mask.numpy_view().flatten()
@ -79,8 +99,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask):
for index in range(num_pixels): for index in range(num_pixels):
consistent_pixels += ( consistent_pixels += (
actual_mask_pixels[index] * actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index]) == expected_mask_pixels[index]
)
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
@ -96,16 +117,27 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp() super().setUp()
# Load the test input image. # Load the test input image.
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path( test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) )
# Loads ground truth segmentation file. # Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread( gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path( test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)), os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
cv2.IMREAD_GRAYSCALE) ),
cv2.IMREAD_GRAYSCALE,
)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data) self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path( self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
)
def _load_segmentation_mask(self, file_path: str):
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
cv2.IMREAD_GRAYSCALE,
)
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -121,9 +153,11 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions( base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite') model_asset_path='/path/to/invalid/model.tflite'
)
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)
@ -135,8 +169,9 @@ class ImageSegmenterTest(parameterized.TestCase):
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters((ModelFileType.FILE_NAME,), @parameterized.parameters(
(ModelFileType.FILE_CONTENT,)) (ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
)
def test_segment_succeeds_with_category_mask(self, model_file_type): def test_segment_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter. # Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
@ -150,22 +185,27 @@ class ImageSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) segmentation_result = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8) self.assertEqual(result_pixels.dtype, np.uint8)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of the'
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
# Closes the segmenter explicitly when the segmenter is not used in # Closes the segmenter explicitly when the segmenter is not used in
# a context. # a context.
@ -175,67 +215,37 @@ class ImageSegmenterTest(parameterized.TestCase):
# Creates segmenter. # Creates segmenter.
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode. # Load the cat image.
options = _ImageSegmenterOptions( test_image = _Image.create_from_file(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
segmenter = _ImageSegmenter.create_from_options(options) )
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK, output_category_mask=False,
activation=_Activation.SOFTMAX) output_confidence_masks=True,
segmenter = _ImageSegmenter.create_from_options(options) )
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
@parameterized.parameters((ModelFileType.FILE_NAME),
(ModelFileType.FILE_CONTENT))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input. segmentation_result = segmenter.segment(test_image)
category_masks = segmenter.segment(self.test_image) confidence_masks = segmentation_result.confidence_masks
self.assertLen(category_masks, 1)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_float_mask(
f'Number of pixels in the candidate mask differing from that of the ' confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') )
)
def test_get_labels_succeeds(self): def test_get_labels_succeeds(self):
expected_labels = _EXPECTED_LABELS expected_labels = _EXPECTED_LABELS
@ -250,9 +260,11 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM) running_mode=_RUNNING_MODE.LIVE_STREAM,
with self.assertRaisesRegex(ValueError, )
r'result callback must be provided'): with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter: with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass pass
@ -261,130 +273,236 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode, running_mode=running_mode,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
with self.assertRaisesRegex(ValueError, )
r'result callback should not be provided'): with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter: with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass pass
def test_calling_segment_for_video_in_image_mode(self): def test_calling_segment_for_video_in_image_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the video mode'): ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_calling_segment_async_in_image_mode(self): def test_calling_segment_async_in_image_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the live stream mode'): ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_calling_segment_in_video_mode(self): def test_calling_segment_in_video_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the image mode'): ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image) segmenter.segment(self.test_image)
def test_calling_segment_async_in_video_mode(self): def test_calling_segment_async_in_video_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the live stream mode'): ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_segment_for_video_with_out_of_order_timestamp(self): def test_segment_for_video_with_out_of_order_timestamp(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
unused_result = segmenter.segment_for_video(self.test_image, 1) unused_result = segmenter.segment_for_video(self.test_image, 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self): def test_segment_for_video_in_category_mask_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
running_mode=_RUNNING_MODE.VIDEO) output_confidence_masks=False,
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp) segmentation_result = segmenter.segment_for_video(
self.assertLen(category_masks, 1) self.test_image, timestamp
)
category_mask = segmentation_result.category_mask
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
def test_segment_for_video_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
output_category_mask=False,
output_confidence_masks=True,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(test_image, timestamp)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
def test_calling_segment_in_live_stream_mode(self): def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the image mode'): ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image) segmenter.segment(self.test_image)
def test_calling_segment_for_video_in_live_stream_mode(self): def test_calling_segment_for_video_in_live_stream_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the video mode'): ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_async_calls_with_illegal_timestamp(self): def test_segment_async_calls_with_illegal_timestamp(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
segmenter.segment_async(self.test_image, 100) segmenter.segment_async(self.test_image, 100)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self): def test_segment_async_calls_in_category_mask_mode(self):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: List[image_module.Image], output_image: _Image, def check_result(
timestamp_ms: int): result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask. # Get the output category mask.
category_mask = result[0] category_mask = result.category_mask
self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width) self.assertEqual(output_image.width, self.test_seg_image.width)
self.assertEqual(output_image.height, self.test_seg_image.height) self.assertEqual(output_image.height, self.test_seg_image.height)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_mask, self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
output_confidence_masks=False,
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp) segmenter.segment_async(self.test_image, timestamp)
def test_segment_async_calls_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
observed_timestamp_ms = -1
def check_result(
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask.
confidence_masks = result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_category_mask=False,
output_confidence_masks=True,
result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision import interactive_segmenter
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint _NormalizedKeypoint = keypoint_module.NormalizedKeypoint
_Rect = rect.Rect _Rect = rect.Rect
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
_RegionOfInterest = interactive_segmenter.RegionOfInterest _RegionOfInterest = interactive_segmenter.RegionOfInterest
@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
) )
segmenter = _InteractiveSegmenter.create_from_options(options) segmenter = _InteractiveSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
category_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask( _similar_to_uint8_mask(
category_masks[0], test_seg_image, similarity_threshold category_mask, test_seg_image, similarity_threshold
), ),
( (
'Number of pixels in the candidate mask differing from that of the' 'Number of pixels in the candidate mask differing from that of the'
@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
confidence_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
confidence_masks = segmenter.segment( segmentation_result = segmenter.segment(
self.test_image, roi, image_processing_options self.test_image, roi, image_processing_options
) )
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(

View File

@ -14,7 +14,6 @@
"""MediaPipe image segmenter task.""" """MediaPipe image segmenter task."""
import dataclasses import dataclasses
import enum
from typing import Callable, List, Mapping, Optional from typing import Callable, List, Mapping, Optional
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
@ -32,7 +31,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
@ -46,8 +44,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
@ -58,6 +58,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class ImageSegmenterResult:
"""Output result of ImageSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmenterOptions: class ImageSegmenterOptions:
"""Options for the image segmenter task. """Options for the image segmenter task.
@ -69,28 +84,17 @@ class ImageSegmenterOptions:
objects on single image inputs. 2) The video mode for segmenting objects objects on single image inputs. 2) The video mode for segmenting objects
on the decoded frames of a video. 3) The live stream mode for segmenting on the decoded frames of a video. 3) The live stream mode for segmenting
objects on a live stream of input data, such as from camera. objects on a live stream of input data, such as from camera.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
activation: Optional[Activation] = Activation.NONE output_category_mask: bool = False
result_callback: Optional[ result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None] Callable[[ImageSegmenterResult, image_module.Image, int], None]
] = None ] = None
@ -102,9 +106,7 @@ class ImageSegmenterOptions:
base_options_proto.use_stream_mode = ( base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True False if self.running_mode == _RunningMode.IMAGE else True
) )
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto()
output_type=self.output_type.value, activation=self.activation.value
)
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto, segmenter_options=segmenter_options_proto,
@ -216,27 +218,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet.Packet]): def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME] segmentation_result = ImageSegmenterResult()
)
if options.output_confidence_masks:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if options.output_category_mask:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
segmentation_result, segmentation_result,
image, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
) )
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -292,9 +315,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_for_video( def segment_for_video(
@ -337,9 +369,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_async( def segment_async(

View File

@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_ROI_STREAM_NAME = 'roi_in' _ROI_STREAM_NAME = 'roi_in'
@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = (
) )
@dataclasses.dataclass
class InteractiveSegmenterResult:
"""Output result of InteractiveSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class InteractiveSegmenterOptions: class InteractiveSegmenterOptions:
"""Options for the interactive segmenter task. """Options for the interactive segmenter task.
Attributes: Attributes:
base_options: Base options for the interactive segmenter task. base_options: Base options for the interactive segmenter task.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
base_options: _BaseOptions base_options: _BaseOptions
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
output_category_mask: bool = False
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an InteractiveSegmenterOptions protobuf object.""" """Generates an InteractiveSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False base_options_proto.use_stream_mode = False
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto()
output_type=self.output_type.value
)
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto, segmenter_options=segmenter_options_proto,
@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If other types of error occurred. RuntimeError: If other types of error occurred.
""" """
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
image: image_module.Image, image: image_module.Image,
roi: RegionOfInterest, roi: RegionOfInterest,
image_processing_options: Optional[_ImageProcessingOptions] = None, image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> List[image_module.Image]: ) -> InteractiveSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image. """Performs the actual segmentation task on the provided MediaPipe Image.
The image can be of any size with format RGB. The image can be of any size with format RGB.
@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = InteractiveSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result

View File

@ -12,72 +12,72 @@ def wasm_files():
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_js", name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646", sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac", sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a", sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7", sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_js", name = "com_google_mediapipe_wasm_text_wasm_internal_js",
sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2", sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65", sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871", sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd", sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_js", name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7", sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3", sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0", sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592", sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"],
) )