Updated image classifier to use a region of interest parameter
This commit is contained in:
		
							parent
							
								
									cb806071ba
								
							
						
					
					
						commit
						44e6f8e1a1
					
				|  | @ -27,6 +27,15 @@ py_library( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | py_library( | ||||||
|  |     name = "rect", | ||||||
|  |     srcs = ["rect.py"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework/formats:rect_py_pb2", | ||||||
|  |         "//mediapipe/tasks/python/core:optional_dependencies", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| py_library( | py_library( | ||||||
|     name = "category", |     name = "category", | ||||||
|     srcs = ["category.py"], |     srcs = ["category.py"], | ||||||
|  |  | ||||||
							
								
								
									
										136
									
								
								mediapipe/tasks/python/components/containers/rect.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								mediapipe/tasks/python/components/containers/rect.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,136 @@ | ||||||
|  | # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | """Rect data class.""" | ||||||
|  | 
 | ||||||
|  | import dataclasses | ||||||
|  | from typing import Any, Optional | ||||||
|  | 
 | ||||||
|  | from mediapipe.framework.formats import rect_pb2 | ||||||
|  | from mediapipe.tasks.python.core.optional_dependencies import doc_controls | ||||||
|  | 
 | ||||||
|  | _RectProto = rect_pb2.Rect | ||||||
|  | _NormalizedRectProto = rect_pb2.NormalizedRect | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclasses.dataclass | ||||||
|  | class Rect: | ||||||
|  |   """A rectangle with rotation in image coordinates. | ||||||
|  | 
 | ||||||
|  |   Attributes: | ||||||
|  |     x_center : The X coordinate of the top-left corner, in pixels. | ||||||
|  |     y_center : The Y coordinate of the top-left corner, in pixels. | ||||||
|  |     width: The width of the rectangle, in pixels. | ||||||
|  |     height: The height of the rectangle, in pixels. | ||||||
|  |     rotation: Rotation angle is clockwise in radians. | ||||||
|  |     rect_id:  Optional unique id to help associate different rectangles to each | ||||||
|  |       other. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   x_center: int | ||||||
|  |   y_center: int | ||||||
|  |   width: int | ||||||
|  |   height: int | ||||||
|  |   rotation: Optional[float] = 0.0 | ||||||
|  |   rect_id: Optional[int] = None | ||||||
|  | 
 | ||||||
|  |   @doc_controls.do_not_generate_docs | ||||||
|  |   def to_pb2(self) -> _RectProto: | ||||||
|  |     """Generates a Rect protobuf object.""" | ||||||
|  |     return _RectProto( | ||||||
|  |         x_center=self.x_center, | ||||||
|  |         y_center=self.y_center, | ||||||
|  |         width=self.width, | ||||||
|  |         height=self.height, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |   @classmethod | ||||||
|  |   @doc_controls.do_not_generate_docs | ||||||
|  |   def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect': | ||||||
|  |     """Creates a `Rect` object from the given protobuf object.""" | ||||||
|  |     return Rect( | ||||||
|  |         x_center=pb2_obj.x_center, | ||||||
|  |         y_center=pb2_obj.y_center, | ||||||
|  |         width=pb2_obj.width, | ||||||
|  |         height=pb2_obj.height) | ||||||
|  | 
 | ||||||
|  |   def __eq__(self, other: Any) -> bool: | ||||||
|  |     """Checks if this object is equal to the given object. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       other: The object to be compared with. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       True if the objects are equal. | ||||||
|  |     """ | ||||||
|  |     if not isinstance(other, Rect): | ||||||
|  |       return False | ||||||
|  | 
 | ||||||
|  |     return self.to_pb2().__eq__(other.to_pb2()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclasses.dataclass | ||||||
|  | class NormalizedRect: | ||||||
|  |   """A rectangle with rotation in normalized coordinates. The values of box | ||||||
|  |     center location and size are within [0, 1]. | ||||||
|  | 
 | ||||||
|  |   Attributes: | ||||||
|  |     x_center : The X normalized coordinate of the top-left corner. | ||||||
|  |     y_center : The Y normalized coordinate of the top-left corner. | ||||||
|  |     width: The width of the rectangle. | ||||||
|  |     height: The height of the rectangle. | ||||||
|  |     rotation: Rotation angle is clockwise in radians. | ||||||
|  |     rect_id:  Optional unique id to help associate different rectangles to each | ||||||
|  |       other. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   x_center: float | ||||||
|  |   y_center: float | ||||||
|  |   width: float | ||||||
|  |   height: float | ||||||
|  |   rotation: Optional[float] = 0.0 | ||||||
|  |   rect_id: Optional[int] = None | ||||||
|  | 
 | ||||||
|  |   @doc_controls.do_not_generate_docs | ||||||
|  |   def to_pb2(self) -> _NormalizedRectProto: | ||||||
|  |     """Generates a NormalizedRect protobuf object.""" | ||||||
|  |     return _NormalizedRectProto( | ||||||
|  |         x_center=self.x_center, | ||||||
|  |         y_center=self.y_center, | ||||||
|  |         width=self.width, | ||||||
|  |         height=self.height, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |   @classmethod | ||||||
|  |   @doc_controls.do_not_generate_docs | ||||||
|  |   def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect': | ||||||
|  |     """Creates a `NormalizedRect` object from the given protobuf object.""" | ||||||
|  |     return NormalizedRect( | ||||||
|  |         x_center=pb2_obj.x_center, | ||||||
|  |         y_center=pb2_obj.y_center, | ||||||
|  |         width=pb2_obj.width, | ||||||
|  |         height=pb2_obj.height) | ||||||
|  | 
 | ||||||
|  |   def __eq__(self, other: Any) -> bool: | ||||||
|  |     """Checks if this object is equal to the given object. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       other: The object to be compared with. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       True if the objects are equal. | ||||||
|  |     """ | ||||||
|  |     if not isinstance(other, NormalizedRect): | ||||||
|  |       return False | ||||||
|  | 
 | ||||||
|  |     return self.to_pb2().__eq__(other.to_pb2()) | ||||||
|  | @ -49,6 +49,7 @@ py_test( | ||||||
|         "//mediapipe/tasks/python/components/processors:classifier_options", |         "//mediapipe/tasks/python/components/processors:classifier_options", | ||||||
|         "//mediapipe/tasks/python/components/containers:category", |         "//mediapipe/tasks/python/components/containers:category", | ||||||
|         "//mediapipe/tasks/python/components/containers:classifications", |         "//mediapipe/tasks/python/components/containers:classifications", | ||||||
|  |         "//mediapipe/tasks/python/components/containers:rect", | ||||||
|         "//mediapipe/tasks/python/core:base_options", |         "//mediapipe/tasks/python/core:base_options", | ||||||
|         "//mediapipe/tasks/python/test:test_utils", |         "//mediapipe/tasks/python/test:test_utils", | ||||||
|         "//mediapipe/tasks/python/vision:image_classifier", |         "//mediapipe/tasks/python/vision:image_classifier", | ||||||
|  |  | ||||||
|  | @ -24,11 +24,13 @@ from mediapipe.python._framework_bindings import image as image_module | ||||||
| from mediapipe.tasks.python.components.processors import classifier_options | from mediapipe.tasks.python.components.processors import classifier_options | ||||||
| from mediapipe.tasks.python.components.containers import category as category_module | from mediapipe.tasks.python.components.containers import category as category_module | ||||||
| from mediapipe.tasks.python.components.containers import classifications as classifications_module | from mediapipe.tasks.python.components.containers import classifications as classifications_module | ||||||
|  | from mediapipe.tasks.python.components.containers import rect as rect_module | ||||||
| from mediapipe.tasks.python.core import base_options as base_options_module | from mediapipe.tasks.python.core import base_options as base_options_module | ||||||
| from mediapipe.tasks.python.test import test_utils | from mediapipe.tasks.python.test import test_utils | ||||||
| from mediapipe.tasks.python.vision import image_classifier | from mediapipe.tasks.python.vision import image_classifier | ||||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | ||||||
| 
 | 
 | ||||||
|  | _NormalizedRect = rect_module.NormalizedRect | ||||||
| _BaseOptions = base_options_module.BaseOptions | _BaseOptions = base_options_module.BaseOptions | ||||||
| _ClassifierOptions = classifier_options.ClassifierOptions | _ClassifierOptions = classifier_options.ClassifierOptions | ||||||
| _Category = category_module.Category | _Category = category_module.Category | ||||||
|  | @ -42,40 +44,6 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode | ||||||
| 
 | 
 | ||||||
| _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' | _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' | ||||||
| _IMAGE_FILE = 'burger.jpg' | _IMAGE_FILE = 'burger.jpg' | ||||||
| _EXPECTED_CATEGORIES = [ |  | ||||||
|     _Category( |  | ||||||
|       index=934, |  | ||||||
|       score=0.7939587831497192, |  | ||||||
|       display_name='', |  | ||||||
|       category_name='cheeseburger'), |  | ||||||
|     _Category( |  | ||||||
|       index=932, |  | ||||||
|       score=0.02739289402961731, |  | ||||||
|       display_name='', |  | ||||||
|       category_name='bagel'), |  | ||||||
|     _Category( |  | ||||||
|       index=925, |  | ||||||
|       score=0.01934075355529785, |  | ||||||
|       display_name='', |  | ||||||
|       category_name='guacamole'), |  | ||||||
|     _Category( |  | ||||||
|       index=963, |  | ||||||
|       score=0.006327860057353973, |  | ||||||
|       display_name='', |  | ||||||
|       category_name='meat loaf') |  | ||||||
| ] |  | ||||||
| _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( |  | ||||||
|   classifications=[ |  | ||||||
|     _Classifications( |  | ||||||
|       entries=[ |  | ||||||
|         _ClassificationEntry( |  | ||||||
|           categories=_EXPECTED_CATEGORIES, |  | ||||||
|           timestamp_ms=0 |  | ||||||
|         ) |  | ||||||
|       ], |  | ||||||
|       head_index=0, |  | ||||||
|       head_name='probability') |  | ||||||
|   ]) |  | ||||||
| _EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( | _EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( | ||||||
|   classifications=[ |   classifications=[ | ||||||
|     _Classifications( |     _Classifications( | ||||||
|  | @ -94,6 +62,60 @@ _SCORE_THRESHOLD = 0.5 | ||||||
| _MAX_RESULTS = 3 | _MAX_RESULTS = 3 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: | ||||||
|  |   return _ClassificationResult( | ||||||
|  |     classifications=[ | ||||||
|  |       _Classifications( | ||||||
|  |         entries=[ | ||||||
|  |           _ClassificationEntry( | ||||||
|  |             categories=[ | ||||||
|  |               _Category( | ||||||
|  |                 index=934, | ||||||
|  |                 score=0.7939587831497192, | ||||||
|  |                 display_name='', | ||||||
|  |                 category_name='cheeseburger'), | ||||||
|  |               _Category( | ||||||
|  |                 index=932, | ||||||
|  |                 score=0.02739289402961731, | ||||||
|  |                 display_name='', | ||||||
|  |                 category_name='bagel'), | ||||||
|  |               _Category( | ||||||
|  |                 index=925, | ||||||
|  |                 score=0.01934075355529785, | ||||||
|  |                 display_name='', | ||||||
|  |                 category_name='guacamole'), | ||||||
|  |               _Category( | ||||||
|  |                 index=963, | ||||||
|  |                 score=0.006327860057353973, | ||||||
|  |                 display_name='', | ||||||
|  |                 category_name='meat loaf') | ||||||
|  |             ], | ||||||
|  |             timestamp_ms=timestamp_ms | ||||||
|  |           ) | ||||||
|  |         ], | ||||||
|  |         head_index=0, | ||||||
|  |         head_name='probability') | ||||||
|  |     ]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: | ||||||
|  |   return _ClassificationResult( | ||||||
|  |     classifications=[ | ||||||
|  |       _Classifications( | ||||||
|  |         entries=[ | ||||||
|  |           _ClassificationEntry( | ||||||
|  |             categories=[ | ||||||
|  |               _Category(index=806, score=0.9965274930000305, display_name='', | ||||||
|  |                         category_name='soccer ball') | ||||||
|  |             ], | ||||||
|  |             timestamp_ms=timestamp_ms | ||||||
|  |           ) | ||||||
|  |         ], | ||||||
|  |         head_index=0, | ||||||
|  |         head_name='probability') | ||||||
|  |     ]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ModelFileType(enum.Enum): | class ModelFileType(enum.Enum): | ||||||
|   FILE_CONTENT = 1 |   FILE_CONTENT = 1 | ||||||
|   FILE_NAME = 2 |   FILE_NAME = 2 | ||||||
|  | @ -138,8 +160,8 @@ class ImageClassifierTest(parameterized.TestCase): | ||||||
|       self.assertIsInstance(classifier, _ImageClassifier) |       self.assertIsInstance(classifier, _ImageClassifier) | ||||||
| 
 | 
 | ||||||
|   @parameterized.parameters( |   @parameterized.parameters( | ||||||
|       (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), |       (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), | ||||||
|       (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) |       (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) | ||||||
|   def test_classify(self, model_file_type, max_results, |   def test_classify(self, model_file_type, max_results, | ||||||
|                     expected_classification_result): |                     expected_classification_result): | ||||||
|     # Creates classifier. |     # Creates classifier. | ||||||
|  | @ -167,8 +189,8 @@ class ImageClassifierTest(parameterized.TestCase): | ||||||
|     classifier.close() |     classifier.close() | ||||||
| 
 | 
 | ||||||
|   @parameterized.parameters( |   @parameterized.parameters( | ||||||
|     (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), |     (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), | ||||||
|     (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) |     (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) | ||||||
|   def test_classify_in_context(self, model_file_type, max_results, |   def test_classify_in_context(self, model_file_type, max_results, | ||||||
|                                expected_classification_result): |                                expected_classification_result): | ||||||
|     if model_file_type is ModelFileType.FILE_NAME: |     if model_file_type is ModelFileType.FILE_NAME: | ||||||
|  | @ -190,6 +212,23 @@ class ImageClassifierTest(parameterized.TestCase): | ||||||
|       # Comparing results. |       # Comparing results. | ||||||
|       self.assertEqual(image_result, expected_classification_result) |       self.assertEqual(image_result, expected_classification_result) | ||||||
| 
 | 
 | ||||||
|  |   def test_classify_succeeds_with_region_of_interest(self): | ||||||
|  |     base_options = _BaseOptions(model_asset_path=self.model_path) | ||||||
|  |     classifier_options = _ClassifierOptions(max_results=1) | ||||||
|  |     options = _ImageClassifierOptions( | ||||||
|  |       base_options=base_options, classifier_options=classifier_options) | ||||||
|  |     with _ImageClassifier.create_from_options(options) as classifier: | ||||||
|  |       # Load the test image. | ||||||
|  |       test_image = _Image.create_from_file( | ||||||
|  |           test_utils.get_test_data_path('multi_objects.jpg')) | ||||||
|  |       # NormalizedRect around the soccer ball. | ||||||
|  |       roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, | ||||||
|  |                             height=0.427) | ||||||
|  |       # Performs image classification on the input. | ||||||
|  |       image_result = classifier.classify(test_image, roi) | ||||||
|  |       # Comparing results. | ||||||
|  |       self.assertEqual(image_result, _generate_soccer_ball_results(0)) | ||||||
|  | 
 | ||||||
|   def test_score_threshold_option(self): |   def test_score_threshold_option(self): | ||||||
|     classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) |     classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) | ||||||
|     options = _ImageClassifierOptions( |     options = _ImageClassifierOptions( | ||||||
|  | @ -353,16 +392,27 @@ class ImageClassifierTest(parameterized.TestCase): | ||||||
|       for timestamp in range(0, 300, 30): |       for timestamp in range(0, 300, 30): | ||||||
|         classification_result = classifier.classify_for_video( |         classification_result = classifier.classify_for_video( | ||||||
|             self.test_image, timestamp) |             self.test_image, timestamp) | ||||||
|         expected_classification_result = _ClassificationResult( |         self.assertEqual(classification_result, | ||||||
|           classifications=[ |                          _generate_burger_results(timestamp)) | ||||||
|             _Classifications( | 
 | ||||||
|               entries=[ |   def test_classify_for_video_succeeds_with_region_of_interest(self): | ||||||
|                 _ClassificationEntry( |     classifier_options = _ClassifierOptions(max_results=1) | ||||||
|                   categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp) |     options = _ImageClassifierOptions( | ||||||
|               ], |       base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|               head_index=0, head_name='probability') |       running_mode=_RUNNING_MODE.VIDEO, | ||||||
|           ]) |       classifier_options=classifier_options) | ||||||
|         self.assertEqual(classification_result, expected_classification_result) |     with _ImageClassifier.create_from_options(options) as classifier: | ||||||
|  |       # Load the test image. | ||||||
|  |       test_image = _Image.create_from_file( | ||||||
|  |         test_utils.get_test_data_path('multi_objects.jpg')) | ||||||
|  |       # NormalizedRect around the soccer ball. | ||||||
|  |       roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, | ||||||
|  |                             height=0.427) | ||||||
|  |       for timestamp in range(0, 300, 30): | ||||||
|  |         classification_result = classifier.classify_for_video( | ||||||
|  |           test_image, timestamp, roi) | ||||||
|  |         self.assertEqual(classification_result, | ||||||
|  |                          _generate_soccer_ball_results(timestamp)) | ||||||
| 
 | 
 | ||||||
|   def test_calling_classify_in_live_stream_mode(self): |   def test_calling_classify_in_live_stream_mode(self): | ||||||
|     options = _ImageClassifierOptions( |     options = _ImageClassifierOptions( | ||||||
|  |  | ||||||
|  | @ -49,6 +49,7 @@ py_library( | ||||||
|         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", |         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", | ||||||
|         "//mediapipe/tasks/python/components/processors:classifier_options", |         "//mediapipe/tasks/python/components/processors:classifier_options", | ||||||
|         "//mediapipe/tasks/python/components/containers:classifications", |         "//mediapipe/tasks/python/components/containers:classifications", | ||||||
|  |         "//mediapipe/tasks/python/components/containers:rect", | ||||||
|         "//mediapipe/tasks/python/core:base_options", |         "//mediapipe/tasks/python/core:base_options", | ||||||
|         "//mediapipe/tasks/python/core:optional_dependencies", |         "//mediapipe/tasks/python/core:optional_dependencies", | ||||||
|         "//mediapipe/tasks/python/core:task_info", |         "//mediapipe/tasks/python/core:task_info", | ||||||
|  |  | ||||||
|  | @ -24,12 +24,14 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu | ||||||
| from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 | from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 | ||||||
| from mediapipe.tasks.python.components.processors import classifier_options | from mediapipe.tasks.python.components.processors import classifier_options | ||||||
| from mediapipe.tasks.python.components.containers import classifications as classifications_module | from mediapipe.tasks.python.components.containers import classifications as classifications_module | ||||||
|  | from mediapipe.tasks.python.components.containers import rect as rect_module | ||||||
| from mediapipe.tasks.python.core import base_options as base_options_module | from mediapipe.tasks.python.core import base_options as base_options_module | ||||||
| from mediapipe.tasks.python.core import task_info as task_info_module | from mediapipe.tasks.python.core import task_info as task_info_module | ||||||
| from mediapipe.tasks.python.core.optional_dependencies import doc_controls | from mediapipe.tasks.python.core.optional_dependencies import doc_controls | ||||||
| from mediapipe.tasks.python.vision.core import base_vision_task_api | from mediapipe.tasks.python.vision.core import base_vision_task_api | ||||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | ||||||
| 
 | 
 | ||||||
|  | _NormalizedRect = rect_module.NormalizedRect | ||||||
| _BaseOptions = base_options_module.BaseOptions | _BaseOptions = base_options_module.BaseOptions | ||||||
| _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions | _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions | ||||||
| _ClassifierOptions = classifier_options.ClassifierOptions | _ClassifierOptions = classifier_options.ClassifierOptions | ||||||
|  | @ -42,10 +44,17 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' | ||||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | _IMAGE_IN_STREAM_NAME = 'image_in' | ||||||
| _IMAGE_OUT_STREAM_NAME = 'image_out' | _IMAGE_OUT_STREAM_NAME = 'image_out' | ||||||
| _IMAGE_TAG = 'IMAGE' | _IMAGE_TAG = 'IMAGE' | ||||||
|  | _NORM_RECT_NAME = 'norm_rect_in' | ||||||
|  | _NORM_RECT_TAG = 'NORM_RECT' | ||||||
| _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' | _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' | ||||||
| _MICRO_SECONDS_PER_MILLISECOND = 1000 | _MICRO_SECONDS_PER_MILLISECOND = 1000 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _build_full_image_norm_rect() -> _NormalizedRect: | ||||||
|  |   # Builds a NormalizedRect covering the entire image. | ||||||
|  |   return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @dataclasses.dataclass | @dataclasses.dataclass | ||||||
| class ImageClassifierOptions: | class ImageClassifierOptions: | ||||||
|   """Options for the image classifier task. |   """Options for the image classifier task. | ||||||
|  | @ -145,6 +154,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|         task_graph=_TASK_GRAPH_NAME, |         task_graph=_TASK_GRAPH_NAME, | ||||||
|         input_streams=[ |         input_streams=[ | ||||||
|             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), |             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), | ||||||
|  |             ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), | ||||||
|         ], |         ], | ||||||
|         output_streams=[ |         output_streams=[ | ||||||
|             ':'.join([_CLASSIFICATION_RESULT_TAG, |             ':'.join([_CLASSIFICATION_RESULT_TAG, | ||||||
|  | @ -161,11 +171,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|   def classify( |   def classify( | ||||||
|       self, |       self, | ||||||
|       image: image_module.Image, |       image: image_module.Image, | ||||||
|  |       roi: Optional[_NormalizedRect] = None | ||||||
|   ) -> classifications_module.ClassificationResult: |   ) -> classifications_module.ClassificationResult: | ||||||
|     """Performs image classification on the provided MediaPipe Image. |     """Performs image classification on the provided MediaPipe Image. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|       image: MediaPipe Image. |       image: MediaPipe Image. | ||||||
|  |       roi: The region of interest. | ||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|       A classification result object that contains a list of classifications. |       A classification result object that contains a list of classifications. | ||||||
|  | @ -174,8 +186,10 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|       ValueError: If any of the input arguments is invalid. |       ValueError: If any of the input arguments is invalid. | ||||||
|       RuntimeError: If image classification failed to run. |       RuntimeError: If image classification failed to run. | ||||||
|     """ |     """ | ||||||
|     output_packets = self._process_image_data( |     norm_rect = roi if roi is not None else _build_full_image_norm_rect() | ||||||
|         {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) |     output_packets = self._process_image_data({ | ||||||
|  |         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), | ||||||
|  |         _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())}) | ||||||
|     classification_result_proto = packet_getter.get_proto( |     classification_result_proto = packet_getter.get_proto( | ||||||
|         output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) |         output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) | ||||||
| 
 | 
 | ||||||
|  | @ -186,7 +200,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
| 
 | 
 | ||||||
|   def classify_for_video( |   def classify_for_video( | ||||||
|       self, image: image_module.Image, |       self, image: image_module.Image, | ||||||
|       timestamp_ms: int |       timestamp_ms: int, | ||||||
|  |       roi: Optional[_NormalizedRect] = None | ||||||
|   ) -> classifications_module.ClassificationResult: |   ) -> classifications_module.ClassificationResult: | ||||||
|     """Performs image classification on the provided video frames. |     """Performs image classification on the provided video frames. | ||||||
| 
 | 
 | ||||||
|  | @ -198,6 +213,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|     Args: |     Args: | ||||||
|       image: MediaPipe Image. |       image: MediaPipe Image. | ||||||
|       timestamp_ms: The timestamp of the input video frame in milliseconds. |       timestamp_ms: The timestamp of the input video frame in milliseconds. | ||||||
|  |       roi: The region of interest. | ||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|       A classification result object that contains a list of classifications. |       A classification result object that contains a list of classifications. | ||||||
|  | @ -206,10 +222,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|       ValueError: If any of the input arguments is invalid. |       ValueError: If any of the input arguments is invalid. | ||||||
|       RuntimeError: If image classification failed to run. |       RuntimeError: If image classification failed to run. | ||||||
|     """ |     """ | ||||||
|  |     norm_rect = roi if roi is not None else _build_full_image_norm_rect() | ||||||
|     output_packets = self._process_video_data({ |     output_packets = self._process_video_data({ | ||||||
|         _IMAGE_IN_STREAM_NAME: |         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( | ||||||
|             packet_creator.create_image(image).at( |             timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), | ||||||
|                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) |         _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( | ||||||
|  |             timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) | ||||||
|     }) |     }) | ||||||
|     classification_result_proto = packet_getter.get_proto( |     classification_result_proto = packet_getter.get_proto( | ||||||
|       output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) |       output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) | ||||||
|  | @ -219,7 +237,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|         for classification in classification_result_proto.classifications |         for classification in classification_result_proto.classifications | ||||||
|     ]) |     ]) | ||||||
| 
 | 
 | ||||||
|   def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: |   def classify_async( | ||||||
|  |       self, | ||||||
|  |       image: image_module.Image, | ||||||
|  |       timestamp_ms: int, | ||||||
|  |       roi: Optional[_NormalizedRect] = None | ||||||
|  |   ) -> None: | ||||||
|     """Sends live image data (an Image with a unique timestamp) to perform |     """Sends live image data (an Image with a unique timestamp) to perform | ||||||
|     image classification. |     image classification. | ||||||
| 
 | 
 | ||||||
|  | @ -241,13 +264,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | ||||||
|     Args: |     Args: | ||||||
|       image: MediaPipe Image. |       image: MediaPipe Image. | ||||||
|       timestamp_ms: The timestamp of the input image in milliseconds. |       timestamp_ms: The timestamp of the input image in milliseconds. | ||||||
|  |       roi: The region of interest. | ||||||
| 
 | 
 | ||||||
|     Raises: |     Raises: | ||||||
|       ValueError: If the current input timestamp is smaller than what the image |       ValueError: If the current input timestamp is smaller than what the image | ||||||
|         classifier has already processed. |         classifier has already processed. | ||||||
|     """ |     """ | ||||||
|  |     norm_rect = roi if roi is not None else _build_full_image_norm_rect() | ||||||
|     self._send_live_stream_data({ |     self._send_live_stream_data({ | ||||||
|         _IMAGE_IN_STREAM_NAME: |         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( | ||||||
|             packet_creator.create_image(image).at( |             timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), | ||||||
|                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) |         _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( | ||||||
|  |             timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) | ||||||
|     }) |     }) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user