Refactored Rect to use top-left coordinates and appropriately updated the Image Classifier and Gesture Recognizer APIs/tests
commit c5765ac836
parent a913255080
@@ -19,75 +19,44 @@ from typing import Any, Optional
 from mediapipe.framework.formats import rect_pb2
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 
-_RectProto = rect_pb2.Rect
 _NormalizedRectProto = rect_pb2.NormalizedRect
 
 
 @dataclasses.dataclass
 class Rect:
-  """A rectangle with rotation in image coordinates.
+  """A rectangle, used e.g. as part of detection results or as input
+  region-of-interest.
+
+  The coordinates are normalized wrt the image dimensions, i.e. generally in
+  [0,1] but they may exceed these bounds if describing a region overlapping the
+  image. The origin is on the top-left corner of the image.
 
   Attributes:
-    x_center : The X coordinate of the top-left corner, in pixels.
-    y_center : The Y coordinate of the top-left corner, in pixels.
-    width: The width of the rectangle, in pixels.
-    height: The height of the rectangle, in pixels.
-    rotation: Rotation angle is clockwise in radians.
-    rect_id: Optional unique id to help associate different rectangles to each
-      other.
+    left: The X coordinate of the left side of the rectangle.
+    top: The Y coordinate of the top of the rectangle.
+    right: The X coordinate of the right side of the rectangle.
+    bottom: The Y coordinate of the bottom of the rectangle.
   """
 
-  x_center: int
-  y_center: int
-  width: int
-  height: int
-  rotation: Optional[float] = 0.0
-  rect_id: Optional[int] = None
-
-  @doc_controls.do_not_generate_docs
-  def to_pb2(self) -> _RectProto:
-    """Generates a Rect protobuf object."""
-    return _RectProto(
-        x_center=self.x_center,
-        y_center=self.y_center,
-        width=self.width,
-        height=self.height,
-    )
-
-  @classmethod
-  @doc_controls.do_not_generate_docs
-  def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
-    """Creates a `Rect` object from the given protobuf object."""
-    return Rect(
-        x_center=pb2_obj.x_center,
-        y_center=pb2_obj.y_center,
-        width=pb2_obj.width,
-        height=pb2_obj.height)
-
-  def __eq__(self, other: Any) -> bool:
-    """Checks if this object is equal to the given object.
-
-    Args:
-      other: The object to be compared with.
-
-    Returns:
-      True if the objects are equal.
-    """
-    if not isinstance(other, Rect):
-      return False
-
-    return self.to_pb2().__eq__(other.to_pb2())
+  left: float
+  top: float
+  right: float
+  bottom: float
 
 
 @dataclasses.dataclass
 class NormalizedRect:
-  """A rectangle with rotation in normalized coordinates.
+  """A rectangle with rotation in normalized coordinates. Location of the center
+  of the rectangle in image coordinates. The (0.0, 0.0) point is at the
+  (top, left) corner.
 
   The values of box
   center location and size are within [0, 1].
 
   Attributes:
-    x_center : The X normalized coordinate of the top-left corner.
-    y_center : The Y normalized coordinate of the top-left corner.
+    x_center: The normalized X coordinate of the rectangle, in
+      image coordinates.
+    y_center: The normalized Y coordinate of the rectangle, in image
+      coordinates.
     width: The width of the rectangle.
     height: The height of the rectangle.
     rotation: Rotation angle is clockwise in radians.
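For reference, a minimal usage sketch of the refactored container (the import path is an assumption based on the MediaPipe source tree): Rect is now a plain dataclass holding normalized, top-left based coordinates and no longer carries its own proto conversion. The values below are the region-of-interest used in the updated image classifier tests.

from mediapipe.tasks.python.components.containers import rect  # assumed path

# Normalized coordinates in [0, 1]; the origin is the top-left corner of the image.
roi = rect.Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
assert roi.left < roi.right and roi.top < roi.bottom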
@@ -54,6 +54,7 @@ py_test(
         "//mediapipe/tasks/python/test:test_utils",
         "//mediapipe/tasks/python/vision:image_classifier",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
     ],
 )
@@ -78,7 +78,7 @@ def _get_expected_gesture_recognition_result(
       file_path)
   with open(landmarks_detection_result_file_path, "rb") as f:
     landmarks_detection_result_proto = _LandmarksDetectionResultProto()
-    # # Use this if a .pb file is available.
+    # Use this if a .pb file is available.
     # landmarks_detection_result_proto.ParseFromString(f.read())
     text_format.Parse(f.read(), landmarks_detection_result_proto)
   landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
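As context for the corrected comment above: the golden file is read as text-format protobuf via text_format.Parse, while the commented-out ParseFromString line is the path for a binary .pb file. A minimal, self-contained sketch of the two parsing paths (Struct is used here only as a stand-in message type):

from google.protobuf import struct_pb2
from google.protobuf import text_format

proto = struct_pb2.Struct()
text_payload = b'fields { key: "score" value { number_value: 0.9 } }'
# Binary .pb payload would use: proto.ParseFromString(binary_payload)
# Text-format .pbtxt payload, as in the test:
text_format.Parse(text_payload, proto)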
@@ -29,8 +29,10 @@ from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.test import test_utils
 from mediapipe.tasks.python.vision import image_classifier
 from mediapipe.tasks.python.vision.core import vision_task_running_mode
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 
 _NormalizedRect = rect.NormalizedRect
+_Rect = rect.Rect
 _BaseOptions = base_options_module.BaseOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _Category = category.Category
@@ -41,6 +43,7 @@ _Image = image.Image
 _ImageClassifier = image_classifier.ImageClassifier
 _ImageClassifierOptions = image_classifier.ImageClassifierOptions
 _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 
 _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
 _IMAGE_FILE = 'burger.jpg'
@@ -226,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
     # Load the test image.
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path('multi_objects.jpg'))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     # Performs image classification on the input.
-    image_result = classifier.classify(test_image, roi)
+    image_result = classifier.classify(test_image, image_processing_options)
     # Comparing results.
     _assert_proto_equals(image_result.to_pb2(),
                          _generate_soccer_ball_results(0).to_pb2())
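The new top-left ROI in this test describes the same soccer-ball region as the center-based NormalizedRect it replaces; a quick arithmetic check using only values from this diff:

left, top, right, bottom = 0.45, 0.3075, 0.614, 0.7345
x_center, y_center = (left + right) / 2.0, (top + bottom) / 2.0  # -> 0.532, 0.521
width, height = right - left, bottom - top                       # -> 0.164, 0.427 (up to float rounding)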
@@ -414,12 +417,12 @@ class ImageClassifierTest(parameterized.TestCase):
     # Load the test image.
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path('multi_objects.jpg'))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     for timestamp in range(0, 300, 30):
       classification_result = classifier.classify_for_video(
-          test_image, timestamp, roi)
+          test_image, timestamp, image_processing_options)
       self.assertEqual(classification_result,
                        _generate_soccer_ball_results(timestamp))
@@ -486,9 +489,9 @@ class ImageClassifierTest(parameterized.TestCase):
     # Load the test image.
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path('multi_objects.jpg'))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     observed_timestamp_ms = -1
 
     def check_result(result: _ClassificationResult, output_image: _Image,
@@ -508,7 +511,8 @@ class ImageClassifierTest(parameterized.TestCase):
         result_callback=check_result)
     with _ImageClassifier.create_from_options(options) as classifier:
       for timestamp in range(0, 300, 30):
-        classifier.classify_async(test_image, timestamp, roi)
+        classifier.classify_async(test_image, timestamp,
+                                  image_processing_options)
 
 
 if __name__ == '__main__':
@@ -56,6 +56,7 @@ py_library(
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
     ],
 )
@@ -96,5 +97,6 @@ py_library(
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
     ],
 )
@@ -160,15 +160,15 @@ class BaseVisionTaskApi(object):
       if not roi_allowed:
         raise ValueError("This task doesn't support region-of-interest.")
       roi = options.region_of_interest
-      if roi.x_center >= roi.width or roi.y_center >= roi.height:
+      if roi.left >= roi.right or roi.top >= roi.bottom:
         raise ValueError(
-            "Expected Rect with x_center < width and y_center < height.")
-      if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1:
+            "Expected Rect with left < right and top < bottom.")
+      if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
         raise ValueError("Expected Rect values to be in [0,1].")
-      normalized_rect.x_center = roi.x_center + roi.width / 2.0
-      normalized_rect.y_center = roi.y_center + roi.height / 2.0
-      normalized_rect.width = roi.width - roi.x_center
-      normalized_rect.height = roi.height - roi.y_center
+      normalized_rect.x_center = (roi.left + roi.right) / 2.0
+      normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
+      normalized_rect.width = roi.right - roi.left
+      normalized_rect.height = roi.bottom - roi.top
       return normalized_rect
 
   def close(self) -> None:
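A standalone sketch of the conversion the base class now performs, i.e. turning a normalized top-left ROI into the center/size form fed to the NORM_RECT input, with the same validation; TopLeftRect and to_center_and_size are illustrative names, not the MediaPipe API:

import dataclasses


@dataclasses.dataclass
class TopLeftRect:  # hypothetical stand-in for the task library's Rect
  left: float
  top: float
  right: float
  bottom: float


def to_center_and_size(roi: TopLeftRect):
  """Returns (x_center, y_center, width, height) for a normalized top-left ROI."""
  if roi.left >= roi.right or roi.top >= roi.bottom:
    raise ValueError("Expected Rect with left < right and top < bottom.")
  if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
    raise ValueError("Expected Rect values to be in [0,1].")
  return ((roi.left + roi.right) / 2.0, (roi.top + roi.bottom) / 2.0,
          roi.right - roi.left, roi.bottom - roi.top)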
@@ -30,7 +30,7 @@ class ImageProcessingOptions:
   Attributes:
     region_of_interest: The optional region-of-interest to crop from the image.
       If not specified, the full image is used. Coordinates must be in [0,1]
-      with 'x_center' < 'width' and 'y_center' < height.
+      with 'left' < 'right' and 'top' < 'bottom'.
     rotation_degress: The rotation to apply to the image (or cropped
       region-of-interest), in degrees clockwise. The rotation must be a
       multiple (positive or negative) of 90°.
@@ -63,7 +63,7 @@ class GestureRecognitionResult:
 
   Attributes:
     gestures: Recognized hand gestures of detected hands. Note that the index
-      of the gesture is always 0, because the raw indices from multiple gesture
+      of the gesture is always -1, because the raw indices from multiple gesture
       classifiers cannot consolidate to a meaningful index.
     handedness: Classification of handedness.
     hand_landmarks: Detected hand landmarks in normalized image coordinates.
@@ -31,12 +31,14 @@ from mediapipe.tasks.python.core import task_info as task_info_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 from mediapipe.tasks.python.vision.core import base_vision_task_api
 from mediapipe.tasks.python.vision.core import vision_task_running_mode
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 
 _NormalizedRect = rect.NormalizedRect
 _BaseOptions = base_options_module.BaseOptions
 _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _RunningMode = vision_task_running_mode.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _TaskInfo = task_info_module.TaskInfo
 
 _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
@@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
 _IMAGE_IN_STREAM_NAME = 'image_in'
 _IMAGE_OUT_STREAM_NAME = 'image_out'
 _IMAGE_TAG = 'IMAGE'
-_NORM_RECT_NAME = 'norm_rect_in'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
 _NORM_RECT_TAG = 'NORM_RECT'
 _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 _MICRO_SECONDS_PER_MILLISECOND = 1000
 
 
-def _build_full_image_norm_rect() -> _NormalizedRect:
-  # Builds a NormalizedRect covering the entire image.
-  return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
-
-
 @dataclasses.dataclass
 class ImageClassifierOptions:
   """Options for the image classifier task.
@@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         task_graph=_TASK_GRAPH_NAME,
         input_streams=[
             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
-            ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
         ],
         output_streams=[
             ':'.join([
@@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
             _RunningMode.LIVE_STREAM), options.running_mode,
         packets_callback if options.result_callback else None)
 
-  # TODO: Replace _NormalizedRect with ImageProcessingOption
   def classify(
       self,
       image: image_module.Image,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided MediaPipe Image.
 
     Args:
       image: MediaPipe Image.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Returns:
       A classification result object that contains a list of classifications.
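For callers, a hedged before/after sketch of the classify signature change, reusing the aliases from the updated tests (classifier and image are assumed to be an existing ImageClassifier and MediaPipe Image):

# Before this commit: the ROI was a center-based NormalizedRect passed directly.
# result = classifier.classify(
#     image, _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, height=0.427))

# After this commit: wrap a top-left based Rect in ImageProcessingOptions.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
result = classifier.classify(image, _ImageProcessingOptions(roi))

# Omitting the argument is still valid; the default is expected to cover the
# full image, matching the removed _build_full_image_norm_rect helper.
result_full = classifier.classify(image)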
@@ -190,10 +186,11 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If image classification failed to run.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     output_packets = self._process_image_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
-        _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
+        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+            normalized_rect.to_pb2())
     })
 
     classification_result_proto = classifications_pb2.ClassificationResult()
@@ -210,7 +207,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       self,
       image: image_module.Image,
       timestamp_ms: int,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided video frames.
 
@@ -222,7 +219,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input video frame in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Returns:
       A classification result object that contains a list of classifications.
|
@ -231,13 +228,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
RuntimeError: If image classification failed to run.
|
RuntimeError: If image classification failed to run.
|
||||||
"""
|
"""
|
||||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||||
output_packets = self._process_video_data({
|
output_packets = self._process_video_data({
|
||||||
_IMAGE_IN_STREAM_NAME:
|
_IMAGE_IN_STREAM_NAME:
|
||||||
packet_creator.create_image(image).at(
|
packet_creator.create_image(image).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
_NORM_RECT_NAME:
|
_NORM_RECT_STREAM_NAME:
|
||||||
packet_creator.create_proto(norm_rect.to_pb2()).at(
|
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@@ -251,10 +248,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         for classification in classification_result_proto.classifications
     ])
 
-  def classify_async(self,
-                     image: image_module.Image,
-                     timestamp_ms: int,
-                     roi: Optional[_NormalizedRect] = None) -> None:
+  def classify_async(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None
+  ) -> None:
     """Sends live image data (an Image with a unique timestamp) to perform image classification.
 
     Only use this method when the ImageClassifier is created with the live
@@ -275,18 +274,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Raises:
       ValueError: If the current input timestamp is smaller than what the image
         classifier has already processed.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     self._send_live_stream_data({
         _IMAGE_IN_STREAM_NAME:
             packet_creator.create_image(image).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
-        _NORM_RECT_NAME:
-            packet_creator.create_proto(norm_rect.to_pb2()).at(
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
     })
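The same ImageProcessingOptions value can be reused across classify, classify_for_video, and classify_async; a short live-stream sketch mirroring the updated test, where classifier_options is an ImageClassifierOptions already configured for live-stream mode with a result_callback (not repeated here):

options = _ImageProcessingOptions(
    _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345))
with _ImageClassifier.create_from_options(classifier_options) as classifier:
  for timestamp_ms in range(0, 300, 30):
    # Results are delivered asynchronously through the result_callback.
    classifier.classify_async(test_image, timestamp_ms, options)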