Refactored Rect to use top-left coordinates and appropriately updated the Image Classifier and Gesture Recognizer APIs/tests

This commit is contained in:
kinaryml 2022-11-01 15:37:00 -07:00
parent a913255080
commit c5765ac836
9 changed files with 75 additions and 100 deletions

View File

@ -19,75 +19,44 @@ from typing import Any, Optional
from mediapipe.framework.formats import rect_pb2 from mediapipe.framework.formats import rect_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_RectProto = rect_pb2.Rect
_NormalizedRectProto = rect_pb2.NormalizedRect _NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass @dataclasses.dataclass
class Rect: class Rect:
"""A rectangle with rotation in image coordinates. """A rectangle, used e.g. as part of detection results or as input
region-of-interest.
Attributes: x_center : The X coordinate of the top-left corner, in pixels. The coordinates are normalized wrt the image dimensions, i.e. generally in
y_center : The Y coordinate of the top-left corner, in pixels. [0,1] but they may exceed these bounds if describing a region overlapping the
width: The width of the rectangle, in pixels. image. The origin is on the top-left corner of the image.
height: The height of the rectangle, in pixels.
rotation: Rotation angle is clockwise in radians. Attributes:
rect_id: Optional unique id to help associate different rectangles to each left: The X coordinate of the left side of the rectangle.
other. top: The Y coordinate of the top of the rectangle.
right: The X coordinate of the right side of the rectangle.
bottom: The Y coordinate of the bottom of the rectangle.
""" """
x_center: int left: float
y_center: int top: float
width: int right: float
height: int bottom: float
rotation: Optional[float] = 0.0
rect_id: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _RectProto:
"""Generates a Rect protobuf object."""
return _RectProto(
x_center=self.x_center,
y_center=self.y_center,
width=self.width,
height=self.height,
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
"""Creates a `Rect` object from the given protobuf object."""
return Rect(
x_center=pb2_obj.x_center,
y_center=pb2_obj.y_center,
width=pb2_obj.width,
height=pb2_obj.height)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Rect):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass @dataclasses.dataclass
class NormalizedRect: class NormalizedRect:
"""A rectangle with rotation in normalized coordinates. """A rectangle with rotation in normalized coordinates. Location of the center
of the rectangle in image coordinates. The (0.0, 0.0) point is at the
(top, left) corner.
The values of box The values of box
center location and size are within [0, 1]. center location and size are within [0, 1].
Attributes: x_center : The X normalized coordinate of the top-left corner. Attributes: x_center: The normalized X coordinate of the rectangle, in
y_center : The Y normalized coordinate of the top-left corner. image coordinates.
y_center: The normalized Y coordinate of the rectangle, in image coordinates.
width: The width of the rectangle. width: The width of the rectangle.
height: The height of the rectangle. height: The height of the rectangle.
rotation: Rotation angle is clockwise in radians. rotation: Rotation angle is clockwise in radians.

View File

@ -54,6 +54,7 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision:image_classifier",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
], ],
) )

View File

@ -78,7 +78,7 @@ def _get_expected_gesture_recognition_result(
file_path) file_path)
with open(landmarks_detection_result_file_path, "rb") as f: with open(landmarks_detection_result_file_path, "rb") as f:
landmarks_detection_result_proto = _LandmarksDetectionResultProto() landmarks_detection_result_proto = _LandmarksDetectionResultProto()
# # Use this if a .pb file is available. # Use this if a .pb file is available.
# landmarks_detection_result_proto.ParseFromString(f.read()) # landmarks_detection_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), landmarks_detection_result_proto) text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(

View File

@ -29,8 +29,10 @@ from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_NormalizedRect = rect.NormalizedRect
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category.Category
@ -41,6 +43,7 @@ _Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier _ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions _ImageClassifierOptions = image_classifier.ImageClassifierOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg' _IMAGE_FILE = 'burger.jpg'
@ -226,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(test_image, roi) image_result = classifier.classify(test_image, image_processing_options)
# Comparing results. # Comparing results.
_assert_proto_equals(image_result.to_pb2(), _assert_proto_equals(image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results(0).to_pb2())
@ -414,12 +417,12 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
test_image, timestamp, roi) test_image, timestamp, image_processing_options)
self.assertEqual(classification_result, self.assertEqual(classification_result,
_generate_soccer_ball_results(timestamp)) _generate_soccer_ball_results(timestamp))
@ -486,9 +489,9 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
@ -508,7 +511,8 @@ class ImageClassifierTest(parameterized.TestCase):
result_callback=check_result) result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp, roi) classifier.classify_async(test_image, timestamp,
image_processing_options)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -56,6 +56,7 @@ py_library(
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
], ],
) )
@ -96,5 +97,6 @@ py_library(
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"//mediapipe/tasks/python/vision/core:image_processing_options",
], ],
) )

View File

@ -160,15 +160,15 @@ class BaseVisionTaskApi(object):
if not roi_allowed: if not roi_allowed:
raise ValueError("This task doesn't support region-of-interest.") raise ValueError("This task doesn't support region-of-interest.")
roi = options.region_of_interest roi = options.region_of_interest
if roi.x_center >= roi.width or roi.y_center >= roi.height: if roi.left >= roi.right or roi.top >= roi.bottom:
raise ValueError( raise ValueError(
"Expected Rect with x_center < width and y_center < height.") "Expected Rect with left < right and top < bottom.")
if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1: if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
raise ValueError("Expected Rect values to be in [0,1].") raise ValueError("Expected Rect values to be in [0,1].")
normalized_rect.x_center = roi.x_center + roi.width / 2.0 normalized_rect.x_center = (roi.left + roi.right) / 2.0
normalized_rect.y_center = roi.y_center + roi.height / 2.0 normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
normalized_rect.width = roi.width - roi.x_center normalized_rect.width = roi.right - roi.left
normalized_rect.height = roi.height - roi.y_center normalized_rect.height = roi.bottom - roi.top
return normalized_rect return normalized_rect
def close(self) -> None: def close(self) -> None:

View File

@ -30,7 +30,7 @@ class ImageProcessingOptions:
Attributes: Attributes:
region_of_interest: The optional region-of-interest to crop from the image. region_of_interest: The optional region-of-interest to crop from the image.
If not specified, the full image is used. Coordinates must be in [0,1] If not specified, the full image is used. Coordinates must be in [0,1]
with 'x_center' < 'width' and 'y_center' < height. with 'left' < 'right' and 'top' < 'bottom'.
rotation_degrees: The rotation to apply to the image (or cropped rotation_degrees: The rotation to apply to the image (or cropped
region-of-interest), in degrees clockwise. The rotation must be a region-of-interest), in degrees clockwise. The rotation must be a
multiple (positive or negative) of 90°. multiple (positive or negative) of 90°.

View File

@ -63,7 +63,7 @@ class GestureRecognitionResult:
Attributes: Attributes:
gestures: Recognized hand gestures of detected hands. Note that the index gestures: Recognized hand gestures of detected hands. Note that the index
of the gesture is always 0, because the raw indices from multiple gesture of the gesture is always -1, because the raw indices from multiple gesture
classifiers cannot be consolidated into a meaningful index. classifiers cannot be consolidated into a meaningful index.
handedness: Classification of handedness. handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates. hand_landmarks: Detected hand landmarks in normalized image coordinates.

View File

@ -31,12 +31,14 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_NORM_RECT_NAME = 'norm_rect_in' _NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT' _NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
def _build_full_image_norm_rect() -> _NormalizedRect:
# Builds a NormalizedRect covering the entire image.
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
@dataclasses.dataclass @dataclasses.dataclass
class ImageClassifierOptions: class ImageClassifierOptions:
"""Options for the image classifier task. """Options for the image classifier task.
@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=[
':'.join([ ':'.join([
@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
_RunningMode.LIVE_STREAM), options.running_mode, _RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None) packets_callback if options.result_callback else None)
# TODO: Replace _NormalizedRect with ImageProcessingOption
def classify( def classify(
self, self,
image: image_module.Image, image: image_module.Image,
roi: Optional[_NormalizedRect] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult: ) -> classifications.ClassificationResult:
"""Performs image classification on the provided MediaPipe Image. """Performs image classification on the provided MediaPipe Image.
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
roi: The region of interest. image_processing_options: Options for image processing.
Returns: Returns:
A classification result object that contains a list of classifications. A classification result object that contains a list of classifications.
@ -190,10 +186,11 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run. RuntimeError: If image classification failed to run.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()) _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())
}) })
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
@ -210,7 +207,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
roi: Optional[_NormalizedRect] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult: ) -> classifications.ClassificationResult:
"""Performs image classification on the provided video frames. """Performs image classification on the provided video frames.
@ -222,7 +219,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds. timestamp_ms: The timestamp of the input video frame in milliseconds.
roi: The region of interest. image_processing_options: Options for image processing.
Returns: Returns:
A classification result object that contains a list of classifications. A classification result object that contains a list of classifications.
@ -231,13 +228,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run. RuntimeError: If image classification failed to run.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at( packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME: _NORM_RECT_STREAM_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at( packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
@ -251,10 +248,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
for classification in classification_result_proto.classifications for classification in classification_result_proto.classifications
]) ])
def classify_async(self, def classify_async(
image: image_module.Image, self,
timestamp_ms: int, image: image_module.Image,
roi: Optional[_NormalizedRect] = None) -> None: timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification. """Sends live image data (an Image with a unique timestamp) to perform image classification.
Only use this method when the ImageClassifier is created with the live Only use this method when the ImageClassifier is created with the live
@ -275,18 +274,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds. timestamp_ms: The timestamp of the input image in milliseconds.
roi: The region of interest. image_processing_options: Options for image processing.
Raises: Raises:
ValueError: If the current input timestamp is smaller than what the image ValueError: If the current input timestamp is smaller than what the image
classifier has already processed. classifier has already processed.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at( packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME: _NORM_RECT_STREAM_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at( packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })