Refactored Rect to use top-left coordinates and appropriately updated the Image Classifier and Gesture Recognizer APIs/tests
This commit is contained in:
parent
a913255080
commit
c5765ac836
|
@ -19,75 +19,44 @@ from typing import Any, Optional
|
|||
from mediapipe.framework.formats import rect_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_RectProto = rect_pb2.Rect
|
||||
_NormalizedRectProto = rect_pb2.NormalizedRect
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Rect:
|
||||
"""A rectangle with rotation in image coordinates.
|
||||
"""A rectangle, used e.g. as part of detection results or as input
|
||||
region-of-interest.
|
||||
|
||||
Attributes: x_center : The X coordinate of the top-left corner, in pixels.
|
||||
y_center : The Y coordinate of the top-left corner, in pixels.
|
||||
width: The width of the rectangle, in pixels.
|
||||
height: The height of the rectangle, in pixels.
|
||||
rotation: Rotation angle is clockwise in radians.
|
||||
rect_id: Optional unique id to help associate different rectangles to each
|
||||
other.
|
||||
The coordinates are normalized wrt the image dimensions, i.e. generally in
|
||||
[0,1] but they may exceed these bounds if describing a region overlapping the
|
||||
image. The origin is on the top-left corner of the image.
|
||||
|
||||
Attributes:
|
||||
left: The X coordinate of the left side of the rectangle.
|
||||
top: The Y coordinate of the top of the rectangle.
|
||||
right: The X coordinate of the right side of the rectangle.
|
||||
bottom: The Y coordinate of the bottom of the rectangle.
|
||||
"""
|
||||
|
||||
x_center: int
|
||||
y_center: int
|
||||
width: int
|
||||
height: int
|
||||
rotation: Optional[float] = 0.0
|
||||
rect_id: Optional[int] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _RectProto:
|
||||
"""Generates a Rect protobuf object."""
|
||||
return _RectProto(
|
||||
x_center=self.x_center,
|
||||
y_center=self.y_center,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
|
||||
"""Creates a `Rect` object from the given protobuf object."""
|
||||
return Rect(
|
||||
x_center=pb2_obj.x_center,
|
||||
y_center=pb2_obj.y_center,
|
||||
width=pb2_obj.width,
|
||||
height=pb2_obj.height)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, Rect):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
left: float
|
||||
top: float
|
||||
right: float
|
||||
bottom: float
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NormalizedRect:
|
||||
"""A rectangle with rotation in normalized coordinates.
|
||||
"""A rectangle with rotation in normalized coordinates. Location of the center
|
||||
of the rectangle in image coordinates. The (0.0, 0.0) point is at the
|
||||
(top, left) corner.
|
||||
|
||||
The values of box
|
||||
|
||||
center location and size are within [0, 1].
|
||||
|
||||
Attributes: x_center : The X normalized coordinate of the top-left corner.
|
||||
y_center : The Y normalized coordinate of the top-left corner.
|
||||
Attributes: x_center: The normalized X coordinate of the rectangle, in
|
||||
image coordinates.
|
||||
y_center: The normalized Y coordinate of the rectangle, in image coordinates.
|
||||
width: The width of the rectangle.
|
||||
height: The height of the rectangle.
|
||||
rotation: Rotation angle is clockwise in radians.
|
||||
|
|
|
@ -54,6 +54,7 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:image_classifier",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ def _get_expected_gesture_recognition_result(
|
|||
file_path)
|
||||
with open(landmarks_detection_result_file_path, "rb") as f:
|
||||
landmarks_detection_result_proto = _LandmarksDetectionResultProto()
|
||||
# # Use this if a .pb file is available.
|
||||
# Use this if a .pb file is available.
|
||||
# landmarks_detection_result_proto.ParseFromString(f.read())
|
||||
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
||||
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
||||
|
|
|
@ -29,8 +29,10 @@ from mediapipe.tasks.python.core import base_options as base_options_module
|
|||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_classifier
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
_NormalizedRect = rect.NormalizedRect
|
||||
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_Category = category.Category
|
||||
|
@ -41,6 +43,7 @@ _Image = image.Image
|
|||
_ImageClassifier = image_classifier.ImageClassifier
|
||||
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
||||
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
||||
_IMAGE_FILE = 'burger.jpg'
|
||||
|
@ -226,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, roi)
|
||||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
_assert_proto_equals(image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
|
@ -414,12 +417,12 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, roi)
|
||||
test_image, timestamp, image_processing_options)
|
||||
self.assertEqual(classification_result,
|
||||
_generate_soccer_ball_results(timestamp))
|
||||
|
||||
|
@ -486,9 +489,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
|
@ -508,7 +511,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
result_callback=check_result)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
classifier.classify_async(test_image, timestamp, roi)
|
||||
classifier.classify_async(test_image, timestamp,
|
||||
image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -56,6 +56,7 @@ py_library(
|
|||
"//mediapipe/tasks/python/core:task_info",
|
||||
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -96,5 +97,6 @@ py_library(
|
|||
"//mediapipe/tasks/python/core:task_info",
|
||||
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -160,15 +160,15 @@ class BaseVisionTaskApi(object):
|
|||
if not roi_allowed:
|
||||
raise ValueError("This task doesn't support region-of-interest.")
|
||||
roi = options.region_of_interest
|
||||
if roi.x_center >= roi.width or roi.y_center >= roi.height:
|
||||
if roi.left >= roi.right or roi.top >= roi.bottom:
|
||||
raise ValueError(
|
||||
"Expected Rect with x_center < width and y_center < height.")
|
||||
if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1:
|
||||
"Expected Rect with left < right and top < bottom.")
|
||||
if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
|
||||
raise ValueError("Expected Rect values to be in [0,1].")
|
||||
normalized_rect.x_center = roi.x_center + roi.width / 2.0
|
||||
normalized_rect.y_center = roi.y_center + roi.height / 2.0
|
||||
normalized_rect.width = roi.width - roi.x_center
|
||||
normalized_rect.height = roi.height - roi.y_center
|
||||
normalized_rect.x_center = (roi.left + roi.right) / 2.0
|
||||
normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
|
||||
normalized_rect.width = roi.right - roi.left
|
||||
normalized_rect.height = roi.bottom - roi.top
|
||||
return normalized_rect
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
|
@ -30,7 +30,7 @@ class ImageProcessingOptions:
|
|||
Attributes:
|
||||
region_of_interest: The optional region-of-interest to crop from the image.
|
||||
If not specified, the full image is used. Coordinates must be in [0,1]
|
||||
with 'x_center' < 'width' and 'y_center' < height.
|
||||
with 'left' < 'right' and 'top' < 'bottom'.
|
||||
rotation_degress: The rotation to apply to the image (or cropped
|
||||
region-of-interest), in degrees clockwise. The rotation must be a
|
||||
multiple (positive or negative) of 90°.
|
||||
|
|
|
@ -63,7 +63,7 @@ class GestureRecognitionResult:
|
|||
|
||||
Attributes:
|
||||
gestures: Recognized hand gestures of detected hands. Note that the index
|
||||
of the gesture is always 0, because the raw indices from multiple gesture
|
||||
of the gesture is always -1, because the raw indices from multiple gesture
|
||||
classifiers cannot consolidate to a meaningful index.
|
||||
handedness: Classification of handedness.
|
||||
hand_landmarks: Detected hand landmarks in normalized image coordinates.
|
||||
|
|
|
@ -31,12 +31,14 @@ from mediapipe.tasks.python.core import task_info as task_info_module
|
|||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
_NormalizedRect = rect.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||
|
@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
|||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
_NORM_RECT_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_TAG = 'NORM_RECT'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
def _build_full_image_norm_rect() -> _NormalizedRect:
|
||||
# Builds a NormalizedRect covering the entire image.
|
||||
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageClassifierOptions:
|
||||
"""Options for the image classifier task.
|
||||
|
@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
task_graph=_TASK_GRAPH_NAME,
|
||||
input_streams=[
|
||||
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([
|
||||
|
@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
_RunningMode.LIVE_STREAM), options.running_mode,
|
||||
packets_callback if options.result_callback else None)
|
||||
|
||||
# TODO: Replace _NormalizedRect with ImageProcessingOption
|
||||
def classify(
|
||||
self,
|
||||
image: image_module.Image,
|
||||
roi: Optional[_NormalizedRect] = None
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> classifications.ClassificationResult:
|
||||
"""Performs image classification on the provided MediaPipe Image.
|
||||
|
||||
Args:
|
||||
image: MediaPipe Image.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
A classification result object that contains a list of classifications.
|
||||
|
@ -190,10 +186,11 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
ValueError: If any of the input arguments is invalid.
|
||||
RuntimeError: If image classification failed to run.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
output_packets = self._process_image_data({
|
||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
|
||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2())
|
||||
})
|
||||
|
||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||
|
@ -210,7 +207,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
self,
|
||||
image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
roi: Optional[_NormalizedRect] = None
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> classifications.ClassificationResult:
|
||||
"""Performs image classification on the provided video frames.
|
||||
|
||||
|
@ -222,7 +219,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
Args:
|
||||
image: MediaPipe Image.
|
||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
A classification result object that contains a list of classifications.
|
||||
|
@ -231,13 +228,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
ValueError: If any of the input arguments is invalid.
|
||||
RuntimeError: If image classification failed to run.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
output_packets = self._process_video_data({
|
||||
_IMAGE_IN_STREAM_NAME:
|
||||
packet_creator.create_image(image).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
_NORM_RECT_NAME:
|
||||
packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||
_NORM_RECT_STREAM_NAME:
|
||||
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
|
||||
|
@ -251,10 +248,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
for classification in classification_result_proto.classifications
|
||||
])
|
||||
|
||||
def classify_async(self,
|
||||
def classify_async(
|
||||
self,
|
||||
image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
roi: Optional[_NormalizedRect] = None) -> None:
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> None:
|
||||
"""Sends live image data (an Image with a unique timestamp) to perform image classification.
|
||||
|
||||
Only use this method when the ImageClassifier is created with the live
|
||||
|
@ -275,18 +274,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
Args:
|
||||
image: MediaPipe Image.
|
||||
timestamp_ms: The timestamp of the input image in milliseconds.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the current input timestamp is smaller than what the image
|
||||
classifier has already processed.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
self._send_live_stream_data({
|
||||
_IMAGE_IN_STREAM_NAME:
|
||||
packet_creator.create_image(image).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
_NORM_RECT_NAME:
|
||||
packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||
_NORM_RECT_STREAM_NAME:
|
||||
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue
Block a user