Refactored Rect to use top-left coordinates and appropriately updated the Image Classifier and Gesture Recognizer APIs/tests

kinaryml 2022-11-01 15:37:00 -07:00
parent a913255080
commit c5765ac836
9 changed files with 75 additions and 100 deletions

View File

@@ -19,75 +19,44 @@ from typing import Any, Optional
from mediapipe.framework.formats import rect_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_RectProto = rect_pb2.Rect
_NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass
class Rect:
"""A rectangle with rotation in image coordinates.
"""A rectangle, used e.g. as part of detection results or as input
region-of-interest.
Attributes: x_center : The X coordinate of the top-left corner, in pixels.
y_center : The Y coordinate of the top-left corner, in pixels.
width: The width of the rectangle, in pixels.
height: The height of the rectangle, in pixels.
rotation: Rotation angle is clockwise in radians.
rect_id: Optional unique id to help associate different rectangles to each
other.
The coordinates are normalized wrt the image dimensions, i.e. generally in
[0,1] but they may exceed these bounds if describing a region overlapping the
image. The origin is on the top-left corner of the image.
Attributes:
left: The X coordinate of the left side of the rectangle.
top: The Y coordinate of the top of the rectangle.
right: The X coordinate of the right side of the rectangle.
bottom: The Y coordinate of the bottom of the rectangle.
"""
x_center: int
y_center: int
width: int
height: int
rotation: Optional[float] = 0.0
rect_id: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _RectProto:
"""Generates a Rect protobuf object."""
return _RectProto(
x_center=self.x_center,
y_center=self.y_center,
width=self.width,
height=self.height,
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
"""Creates a `Rect` object from the given protobuf object."""
return Rect(
x_center=pb2_obj.x_center,
y_center=pb2_obj.y_center,
width=pb2_obj.width,
height=pb2_obj.height)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Rect):
return False
return self.to_pb2().__eq__(other.to_pb2())
left: float
top: float
right: float
bottom: float
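
For orientation, the refactored container is constructed directly from its four normalized edges; a minimal sketch using the soccer-ball ROI from the updated tests below (the derived width/height lines are illustrative, not part of the class):

# All four edges are normalized to the image dimensions, origin at the top-left corner.
roi = Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
# Width and height fall out as simple edge differences.
width = roi.right - roi.left   # 0.164
height = roi.bottom - roi.top  # 0.427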
@dataclasses.dataclass
class NormalizedRect:
"""A rectangle with rotation in normalized coordinates.
"""A rectangle with rotation in normalized coordinates. Location of the center
of the rectangle in image coordinates. The (0.0, 0.0) point is at the
(top, left) corner.
The values of box
center location and size are within [0, 1].
Attributes: x_center : The X normalized coordinate of the top-left corner.
y_center : The Y normalized coordinate of the top-left corner.
Attributes: x_center: The normalized X coordinate of the rectangle, in
image coordinates.
y_center: The normalized Y coordinate of the rectangle, in image coordinates.
width: The width of the rectangle.
height: The height of the rectangle.
rotation: Rotation angle is clockwise in radians.

View File

@@ -54,6 +54,7 @@ py_test(
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
],
)

View File

@@ -78,7 +78,7 @@ def _get_expected_gesture_recognition_result(
file_path)
with open(landmarks_detection_result_file_path, "rb") as f:
landmarks_detection_result_proto = _LandmarksDetectionResultProto()
# # Use this if a .pb file is available.
# Use this if a .pb file is available.
# landmarks_detection_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(

View File

@@ -29,8 +29,10 @@ from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import vision_task_running_mode
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_NormalizedRect = rect.NormalizedRect
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
@@ -41,6 +43,7 @@ _Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg'
@@ -226,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
# Performs image classification on the input.
image_result = classifier.classify(test_image, roi)
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
_assert_proto_equals(image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2())
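
The edge values in this test describe exactly the region the old center/size ROI did; a quick arithmetic check in plain Python (nothing MediaPipe-specific is assumed):

# Old representation: center (0.532, 0.521), size 0.164 x 0.427.
x_center, y_center, width, height = 0.532, 0.521, 0.164, 0.427
assert abs((x_center - width / 2) - 0.45) < 1e-9     # left
assert abs((y_center - height / 2) - 0.3075) < 1e-9  # top
assert abs((x_center + width / 2) - 0.614) < 1e-9    # right
assert abs((y_center + height / 2) - 0.7345) < 1e-9  # bottom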
@@ -414,12 +417,12 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, roi)
test_image, timestamp, image_processing_options)
self.assertEqual(classification_result,
_generate_soccer_ball_results(timestamp))
@@ -486,9 +489,9 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
@@ -508,7 +511,8 @@ class ImageClassifierTest(parameterized.TestCase):
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp, roi)
classifier.classify_async(test_image, timestamp,
image_processing_options)
if __name__ == '__main__':
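
In live-stream mode classify_async() returns nothing directly; results arrive through the result_callback configured in the options, as in the hunk above. A sketch of the callback shape, with the first two parameter names taken from this diff and the timestamp parameter assumed:

def check_result(result: _ClassificationResult, output_image: _Image,
                 timestamp_ms: int):
  # Invoked asynchronously once per frame sent via classify_async();
  # input timestamps must be monotonically increasing.
  ...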

View File

@@ -56,6 +56,7 @@ py_library(
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
],
)
@@ -96,5 +97,6 @@ py_library(
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
],
)

View File

@@ -160,15 +160,15 @@ class BaseVisionTaskApi(object):
if not roi_allowed:
raise ValueError("This task doesn't support region-of-interest.")
roi = options.region_of_interest
if roi.x_center >= roi.width or roi.y_center >= roi.height:
if roi.left >= roi.right or roi.top >= roi.bottom:
raise ValueError(
"Expected Rect with x_center < width and y_center < height.")
if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1:
"Expected Rect with left < right and top < bottom.")
if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
raise ValueError("Expected Rect values to be in [0,1].")
normalized_rect.x_center = roi.x_center + roi.width / 2.0
normalized_rect.y_center = roi.y_center + roi.height / 2.0
normalized_rect.width = roi.width - roi.x_center
normalized_rect.height = roi.height - roi.y_center
normalized_rect.x_center = (roi.left + roi.right) / 2.0
normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
normalized_rect.width = roi.right - roi.left
normalized_rect.height = roi.bottom - roi.top
return normalized_rect
def close(self) -> None:
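
Putting this hunk together: a self-contained sketch of the new edge-to-center conversion and its validation, with stand-in dataclasses for the MediaPipe containers. Note the NormalizedRect defaults cover the full image, which is what makes the _build_full_image_norm_rect helper (deleted later in this diff) unnecessary:

import dataclasses

@dataclasses.dataclass
class Rect:  # stand-in for the refactored container
  left: float
  top: float
  right: float
  bottom: float

@dataclasses.dataclass
class NormalizedRect:  # stand-in; defaults cover the full image
  x_center: float = 0.5
  y_center: float = 0.5
  width: float = 1.0
  height: float = 1.0

def convert_to_normalized_rect(roi: Rect) -> NormalizedRect:
  # Reject inverted or degenerate rectangles first.
  if roi.left >= roi.right or roi.top >= roi.bottom:
    raise ValueError("Expected Rect with left < right and top < bottom.")
  # All four edges must stay inside the normalized image bounds.
  if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
    raise ValueError("Expected Rect values to be in [0,1].")
  return NormalizedRect(
      x_center=(roi.left + roi.right) / 2.0,
      y_center=(roi.top + roi.bottom) / 2.0,
      width=roi.right - roi.left,
      height=roi.bottom - roi.top)

nrect = convert_to_normalized_rect(
    Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345))
assert abs(nrect.x_center - 0.532) < 1e-9 and abs(nrect.width - 0.164) < 1e-9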

View File

@@ -30,7 +30,7 @@ class ImageProcessingOptions:
Attributes:
region_of_interest: The optional region-of-interest to crop from the image.
If not specified, the full image is used. Coordinates must be in [0,1]
with 'x_center' < 'width' and 'y_center' < height.
with 'left' < 'right' and 'top' < 'bottom'.
rotation_degrees: The rotation to apply to the image (or cropped
region-of-interest), in degrees clockwise. The rotation must be a
multiple (positive or negative) of 90°.
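
For example, cropping to a quadrant and rotating the crop clockwise (the attribute names are the ones documented here; the containers import path is assumed from this commit's BUILD changes):

from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

# Crop the top-left quadrant, then rotate the crop 90 degrees clockwise.
options = image_processing_options_module.ImageProcessingOptions(
    region_of_interest=rect.Rect(left=0.0, top=0.0, right=0.5, bottom=0.5),
    rotation_degrees=90)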

View File

@@ -63,7 +63,7 @@ class GestureRecognitionResult:
Attributes:
gestures: Recognized hand gestures of detected hands. Note that the index
of the gesture is always 0, because the raw indices from multiple gesture
of the gesture is always -1, because the raw indices from multiple gesture
classifiers cannot consolidate to a meaningful index.
handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates.

View File

@@ -31,12 +31,14 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
@@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_NAME = 'norm_rect_in'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
def _build_full_image_norm_rect() -> _NormalizedRect:
# Builds a NormalizedRect covering the entire image.
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
@dataclasses.dataclass
class ImageClassifierOptions:
"""Options for the image classifier task.
@@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([
@@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
# TODO: Replace _NormalizedRect with ImageProcessingOption
def classify(
self,
image: image_module.Image,
roi: Optional[_NormalizedRect] = None
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult:
"""Performs image classification on the provided MediaPipe Image.
Args:
image: MediaPipe Image.
roi: The region of interest.
image_processing_options: Options for image processing.
Returns:
A classification result object that contains a list of classifications.
@@ -190,10 +186,11 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())
})
classification_result_proto = classifications_pb2.ClassificationResult()
@@ -210,7 +207,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
roi: Optional[_NormalizedRect] = None
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult:
"""Performs image classification on the provided video frames.
@@ -222,7 +219,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
roi: The region of interest.
image_processing_options: Options for image processing.
Returns:
A classification result object that contains a list of classifications.
@@ -231,13 +228,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at(
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
@@ -251,10 +248,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
for classification in classification_result_proto.classifications
])
def classify_async(self,
def classify_async(
self,
image: image_module.Image,
timestamp_ms: int,
roi: Optional[_NormalizedRect] = None) -> None:
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification.
Only use this method when the ImageClassifier is created with the live
@@ -275,18 +274,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
roi: The region of interest.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at(
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
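
Taken together, every classify entry point now takes an optional ImageProcessingOptions in place of a raw NormalizedRect. An end-to-end sketch in the default image mode, assuming the import paths used elsewhere in this diff and placeholder asset paths:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

options = image_classifier.ImageClassifierOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='mobilenet_v2_1.0_224.tflite'))  # placeholder path

with image_classifier.ImageClassifier.create_from_options(options) as classifier:
  test_image = image_module.Image.create_from_file('multi_objects.jpg')
  # Classify only the soccer-ball region, expressed with the new edge-based Rect.
  roi = rect.Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
  result = classifier.classify(
      test_image,
      image_processing_options_module.ImageProcessingOptions(
          region_of_interest=roi))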