Revised implementation to align with recent changes
This commit is contained in:
parent
aac7ff946f
commit
f241630b56
|
@ -53,7 +53,7 @@ py_library(
|
||||||
srcs = ["classifications.py"],
|
srcs = ["classifications.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":category",
|
":category",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
|
@ -16,8 +16,8 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers import classifications_pb2
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers.proto import category as category_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
|
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
|
|
@ -22,8 +22,7 @@ py_library(
|
||||||
name = "classifier_options",
|
name = "classifier_options",
|
||||||
srcs = ["classifier_options.py"],
|
srcs = ["classifier_options.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.proto import classifier_options_pb2
|
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
|
@ -46,11 +46,11 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/python/components/proto:classifier_options",
|
"//mediapipe/tasks/python/components/processors/proto:classifier_options",
|
||||||
"//mediapipe/tasks/python/components/containers:category",
|
"//mediapipe/tasks/python/components/containers/proto:category",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers/proto:classifications",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_util",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
"//mediapipe/tasks/python/vision:image_classifier",
|
"//mediapipe/tasks/python/vision:image_classifier",
|
||||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||||
],
|
],
|
||||||
|
|
|
@ -19,11 +19,11 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.python.components.proto import classifier_options
|
from mediapipe.tasks.python.components.processors.proto import classifier_options
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers.proto import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.test import test_util
|
from mediapipe.tasks.python.test import test_utils
|
||||||
from mediapipe.tasks.python.vision import image_classifier
|
from mediapipe.tasks.python.vision import image_classifier
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
|
@ -88,9 +88,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.test_image = test_util.read_test_image(
|
self.test_image = _Image.create_from_file(
|
||||||
test_util.get_test_data_path(_IMAGE_FILE))
|
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||||
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
|
||||||
|
|
|
@ -46,9 +46,9 @@ py_library(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/proto:classifier_options",
|
"//mediapipe/tasks/python/components/processors/proto:classifier_options",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers/proto:classifications",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
"//mediapipe/tasks/python/core:task_info",
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
|
|
@ -21,9 +21,9 @@ from mediapipe.python import packet_getter
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2
|
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.proto import classifier_options
|
from mediapipe.tasks.python.components.processors.proto import classifier_options
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions
|
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
@ -41,7 +41,7 @@ _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph'
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -70,13 +70,13 @@ class ImageClassifierOptions:
|
||||||
None]] = None
|
None]] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageClassifierOptionsProto:
|
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
||||||
"""Generates an ImageClassifierOptions protobuf object."""
|
"""Generates an ImageClassifierOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||||
classifier_options_proto = self.classifier_options.to_pb2()
|
classifier_options_proto = self.classifier_options.to_pb2()
|
||||||
|
|
||||||
return _ImageClassifierOptionsProto(
|
return _ImageClassifierGraphOptionsProto(
|
||||||
base_options=base_options_proto,
|
base_options=base_options_proto,
|
||||||
classifier_options=classifier_options_proto
|
classifier_options=classifier_options_proto
|
||||||
)
|
)
|
||||||
|
@ -138,7 +138,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
|
input_streams=[
|
||||||
|
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||||
|
],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([_CLASSIFICATION_RESULT_TAG,
|
':'.join([_CLASSIFICATION_RESULT_TAG,
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
|
||||||
|
@ -153,7 +155,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
# TODO: Create an Image class for MediaPipe Tasks.
|
# TODO: Create an Image class for MediaPipe Tasks.
|
||||||
def classify(
|
def classify(
|
||||||
self,
|
self,
|
||||||
image: image_module.Image
|
image: image_module.Image,
|
||||||
) -> classifications_module.ClassificationResult:
|
) -> classifications_module.ClassificationResult:
|
||||||
"""Performs image classification on the provided MediaPipe Image.
|
"""Performs image classification on the provided MediaPipe Image.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user