Revised implementation to align with recent changes
This commit is contained in:
		
							parent
							
								
									aac7ff946f
								
							
						
					
					
						commit
						f241630b56
					
				|  | @ -53,7 +53,7 @@ py_library( | |||
|     srcs = ["classifications.py"], | ||||
|     deps = [ | ||||
|         ":category", | ||||
|         "//mediapipe/tasks/cc/components/containers:classifications_py_pb2", | ||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", | ||||
|         "//mediapipe/tasks/python/core:optional_dependencies", | ||||
|     ], | ||||
| ) | ||||
|  | @ -16,8 +16,8 @@ | |||
| import dataclasses | ||||
| from typing import Any, List, Optional | ||||
| 
 | ||||
| from mediapipe.tasks.cc.components.containers import classifications_pb2 | ||||
| from mediapipe.tasks.python.components.containers import category as category_module | ||||
| from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 | ||||
| from mediapipe.tasks.python.components.containers.proto import category as category_module | ||||
| from mediapipe.tasks.python.core.optional_dependencies import doc_controls | ||||
| 
 | ||||
| _ClassificationEntryProto = classifications_pb2.ClassificationEntry | ||||
|  | @ -22,8 +22,7 @@ py_library( | |||
|     name = "classifier_options", | ||||
|     srcs = ["classifier_options.py"], | ||||
|     deps = [ | ||||
|         "//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2", | ||||
|         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", | ||||
|         "//mediapipe/tasks/python/core:optional_dependencies", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | @ -16,7 +16,7 @@ | |||
| import dataclasses | ||||
| from typing import Any, List, Optional | ||||
| 
 | ||||
| from mediapipe.tasks.cc.components.proto import classifier_options_pb2 | ||||
| from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 | ||||
| from mediapipe.tasks.python.core.optional_dependencies import doc_controls | ||||
| 
 | ||||
| _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions | ||||
|  | @ -46,11 +46,11 @@ py_test( | |||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/python:_framework_bindings", | ||||
|         "//mediapipe/tasks/python/components/proto:classifier_options", | ||||
|         "//mediapipe/tasks/python/components/containers:category", | ||||
|         "//mediapipe/tasks/python/components/containers:classifications", | ||||
|         "//mediapipe/tasks/python/components/processors/proto:classifier_options", | ||||
|         "//mediapipe/tasks/python/components/containers/proto:category", | ||||
|         "//mediapipe/tasks/python/components/containers/proto:classifications", | ||||
|         "//mediapipe/tasks/python/core:base_options", | ||||
|         "//mediapipe/tasks/python/test:test_util", | ||||
|         "//mediapipe/tasks/python/test:test_utils", | ||||
|         "//mediapipe/tasks/python/vision:image_classifier", | ||||
|         "//mediapipe/tasks/python/vision/core:vision_task_running_mode", | ||||
|     ], | ||||
|  |  | |||
|  | @ -19,11 +19,11 @@ from absl.testing import absltest | |||
| from absl.testing import parameterized | ||||
| 
 | ||||
| from mediapipe.python._framework_bindings import image as image_module | ||||
| from mediapipe.tasks.python.components.proto import classifier_options | ||||
| from mediapipe.tasks.python.components.containers import category as category_module | ||||
| from mediapipe.tasks.python.components.containers import classifications as classifications_module | ||||
| from mediapipe.tasks.python.components.processors.proto import classifier_options | ||||
| from mediapipe.tasks.python.components.containers.proto import category as category_module | ||||
| from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module | ||||
| from mediapipe.tasks.python.core import base_options as base_options_module | ||||
| from mediapipe.tasks.python.test import test_util | ||||
| from mediapipe.tasks.python.test import test_utils | ||||
| from mediapipe.tasks.python.vision import image_classifier | ||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | ||||
| 
 | ||||
|  | @ -88,9 +88,9 @@ class ImageClassifierTest(parameterized.TestCase): | |||
| 
 | ||||
|   def setUp(self): | ||||
|     super().setUp() | ||||
|     self.test_image = test_util.read_test_image( | ||||
|         test_util.get_test_data_path(_IMAGE_FILE)) | ||||
|     self.model_path = test_util.get_test_data_path(_MODEL_FILE) | ||||
|     self.test_image = _Image.create_from_file( | ||||
|         test_utils.get_test_data_path(_IMAGE_FILE)) | ||||
|     self.model_path = test_utils.get_test_data_path(_MODEL_FILE) | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), | ||||
|  |  | |||
|  | @ -46,9 +46,9 @@ py_library( | |||
|         "//mediapipe/python:_framework_bindings", | ||||
|         "//mediapipe/python:packet_creator", | ||||
|         "//mediapipe/python:packet_getter", | ||||
|         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2", | ||||
|         "//mediapipe/tasks/python/components/proto:classifier_options", | ||||
|         "//mediapipe/tasks/python/components/containers:classifications", | ||||
|         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", | ||||
|         "//mediapipe/tasks/python/components/processors/proto:classifier_options", | ||||
|         "//mediapipe/tasks/python/components/containers/proto:classifications", | ||||
|         "//mediapipe/tasks/python/core:base_options", | ||||
|         "//mediapipe/tasks/python/core:optional_dependencies", | ||||
|         "//mediapipe/tasks/python/core:task_info", | ||||
|  |  | |||
|  | @ -21,9 +21,9 @@ from mediapipe.python import packet_getter | |||
| from mediapipe.python._framework_bindings import image as image_module | ||||
| from mediapipe.python._framework_bindings import packet as packet_module | ||||
| from mediapipe.python._framework_bindings import task_runner as task_runner_module | ||||
| from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2 | ||||
| from mediapipe.tasks.python.components.proto import classifier_options | ||||
| from mediapipe.tasks.python.components.containers import classifications as classifications_module | ||||
| from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 | ||||
| from mediapipe.tasks.python.components.processors.proto import classifier_options | ||||
| from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module | ||||
| from mediapipe.tasks.python.core import base_options as base_options_module | ||||
| from mediapipe.tasks.python.core import task_info as task_info_module | ||||
| from mediapipe.tasks.python.core.optional_dependencies import doc_controls | ||||
|  | @ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api | |||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module | ||||
| 
 | ||||
| _BaseOptions = base_options_module.BaseOptions | ||||
| _ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions | ||||
| _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions | ||||
| _ClassifierOptions = classifier_options.ClassifierOptions | ||||
| _RunningMode = running_mode_module.VisionTaskRunningMode | ||||
| _TaskInfo = task_info_module.TaskInfo | ||||
|  | @ -41,7 +41,7 @@ _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' | |||
| _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' | ||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | ||||
| _IMAGE_TAG = 'IMAGE' | ||||
| _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph' | ||||
| _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' | ||||
| 
 | ||||
| 
 | ||||
| @dataclasses.dataclass | ||||
|  | @ -70,13 +70,13 @@ class ImageClassifierOptions: | |||
|                None]] = None | ||||
| 
 | ||||
|   @doc_controls.do_not_generate_docs | ||||
|   def to_pb2(self) -> _ImageClassifierOptionsProto: | ||||
|   def to_pb2(self) -> _ImageClassifierGraphOptionsProto: | ||||
|     """Generates an ImageClassifierOptions protobuf object.""" | ||||
|     base_options_proto = self.base_options.to_pb2() | ||||
|     base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True | ||||
|     classifier_options_proto = self.classifier_options.to_pb2() | ||||
| 
 | ||||
|     return _ImageClassifierOptionsProto( | ||||
|     return _ImageClassifierGraphOptionsProto( | ||||
|         base_options=base_options_proto, | ||||
|         classifier_options=classifier_options_proto | ||||
|     ) | ||||
|  | @ -138,7 +138,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | |||
| 
 | ||||
|     task_info = _TaskInfo( | ||||
|         task_graph=_TASK_GRAPH_NAME, | ||||
|         input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], | ||||
|         input_streams=[ | ||||
|             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), | ||||
|         ], | ||||
|         output_streams=[ | ||||
|             ':'.join([_CLASSIFICATION_RESULT_TAG, | ||||
|                       _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) | ||||
|  | @ -153,7 +155,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): | |||
|   # TODO: Create an Image class for MediaPipe Tasks. | ||||
|   def classify( | ||||
|       self, | ||||
|       image: image_module.Image | ||||
|       image: image_module.Image, | ||||
|   ) -> classifications_module.ClassificationResult: | ||||
|     """Performs image classification on the provided MediaPipe Image. | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user