Moved ClassifierOptions to mediapipe/tasks/python/components to align with the C++ API
This commit is contained in:
parent
7287e5a0ed
commit
d8f7c5a43b
28
mediapipe/tasks/python/components/BUILD
Normal file
28
mediapipe/tasks/python/components/BUILD
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "classifier_options",
|
||||||
|
srcs = ["classifier_options.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components:classifier_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
92
mediapipe/tasks/python/components/classifier_options.py
Normal file
92
mediapipe/tasks/python/components/classifier_options.py
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Classifier options data class."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from mediapipe.tasks.cc.components import classifier_options_pb2
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ClassifierOptions:
|
||||||
|
"""Options for classification processor.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
display_names_locale: The locale to use for display names specified through
|
||||||
|
the TFLite Model Metadata.
|
||||||
|
max_results: The maximum number of top-scored classification results to
|
||||||
|
return.
|
||||||
|
score_threshold: Overrides the ones provided in the model metadata. Results
|
||||||
|
below this value are rejected.
|
||||||
|
category_allowlist: Allowlist of category names. If non-empty, detection
|
||||||
|
results whose category name is not in this set will be filtered out.
|
||||||
|
Duplicate or unknown category names are ignored. Mutually exclusive with
|
||||||
|
`category_denylist`.
|
||||||
|
category_denylist: Denylist of category names. If non-empty, detection
|
||||||
|
results whose category name is in this set will be filtered out. Duplicate
|
||||||
|
or unknown category names are ignored. Mutually exclusive with
|
||||||
|
`category_allowlist`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
display_names_locale: Optional[str] = None
|
||||||
|
max_results: Optional[int] = None
|
||||||
|
score_threshold: Optional[float] = None
|
||||||
|
category_allowlist: Optional[List[str]] = None
|
||||||
|
category_denylist: Optional[List[str]] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassifierOptionsProto:
|
||||||
|
"""Generates a ClassifierOptions protobuf object."""
|
||||||
|
return _ClassifierOptionsProto(
|
||||||
|
score_threshold=self.score_threshold,
|
||||||
|
category_allowlist=self.category_allowlist,
|
||||||
|
category_denylist=self.category_denylist,
|
||||||
|
display_names_locale=self.display_names_locale,
|
||||||
|
max_results=self.max_results)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls,
|
||||||
|
pb2_obj: _ClassifierOptionsProto
|
||||||
|
) -> 'ClassifierOptions':
|
||||||
|
"""Creates a `ClassifierOptions` object from the given protobuf object."""
|
||||||
|
return ClassifierOptions(
|
||||||
|
score_threshold=pb2_obj.score_threshold,
|
||||||
|
category_allowlist=[
|
||||||
|
str(name) for name in pb2_obj.class_name_allowlist
|
||||||
|
],
|
||||||
|
category_denylist=[
|
||||||
|
str(name) for name in pb2_obj.class_name_denylist
|
||||||
|
],
|
||||||
|
display_names_locale=pb2_obj.display_names_locale,
|
||||||
|
max_results=pb2_obj.max_results)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, ClassifierOptions):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
|
from mediapipe.tasks.python.components import classifier_options
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -27,6 +28,7 @@ from mediapipe.tasks.python.vision import image_classification
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_Category = category_module.Category
|
_Category = category_module.Category
|
||||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||||
_Classifications = classifications_module.Classifications
|
_Classifications = classifications_module.Classifications
|
||||||
|
@ -136,8 +138,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
classifier_options = _ClassifierOptions(max_results=max_results)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=base_options, max_results=max_results)
|
base_options=base_options, classifier_options=classifier_options)
|
||||||
classifier = _ImageClassifier.create_from_options(options)
|
classifier = _ImageClassifier.create_from_options(options)
|
||||||
|
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
|
@ -163,8 +166,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
classifier_options = _ClassifierOptions(max_results=max_results)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=base_options, max_results=max_results)
|
base_options=base_options, classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
@ -172,9 +176,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
self.assertEqual(image_result, expected_classification_result)
|
self.assertEqual(image_result, expected_classification_result)
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
|
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path),
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
score_threshold=_SCORE_THRESHOLD)
|
classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
@ -189,9 +194,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
f'{classification}')
|
f'{classification}')
|
||||||
|
|
||||||
def test_max_results_option(self):
|
def test_max_results_option(self):
|
||||||
|
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path),
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
max_results=_MAX_RESULTS)
|
classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
@ -201,9 +207,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
||||||
|
|
||||||
def test_allow_list_option(self):
|
def test_allow_list_option(self):
|
||||||
|
classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path),
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
category_allowlist=_ALLOW_LIST)
|
classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
@ -216,9 +223,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
f'Label {label} found but not in label allow list')
|
f'Label {label} found but not in label allow list')
|
||||||
|
|
||||||
def test_deny_list_option(self):
|
def test_deny_list_option(self):
|
||||||
|
classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path),
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
category_denylist=_DENY_LIST)
|
classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
@ -236,16 +244,19 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
ValueError,
|
ValueError,
|
||||||
r'`category_allowlist` and `category_denylist` are mutually '
|
r'`category_allowlist` and `category_denylist` are mutually '
|
||||||
r'exclusive options.'):
|
r'exclusive options.'):
|
||||||
|
classifier_options = _ClassifierOptions(category_allowlist=['foo'],
|
||||||
|
category_denylist=['bar'])
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path),
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
category_allowlist=['foo'],
|
classifier_options=classifier_options)
|
||||||
category_denylist=['bar'])
|
|
||||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_empty_classification_outputs(self):
|
def test_empty_classification_outputs(self):
|
||||||
|
classifier_options = _ClassifierOptions(score_threshold=1)
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
base_options=_BaseOptions(file_name=self.model_path), score_threshold=1)
|
base_options=_BaseOptions(file_name=self.model_path),
|
||||||
|
classifier_options=classifier_options)
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
|
|
|
@ -46,8 +46,8 @@ py_library(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/components:classifier_options_py_pb2",
|
|
||||||
"//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components:classifier_options",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers:classifications",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
|
|
@ -21,8 +21,8 @@ from mediapipe.python import packet_getter
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.tasks.cc.components import classifier_options_pb2
|
|
||||||
from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2
|
from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2
|
||||||
|
from mediapipe.tasks.python.components import classifier_options
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
|
@ -31,8 +31,8 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
|
|
||||||
_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions
|
_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
_TaskRunner = task_runner_module.TaskRunner
|
||||||
|
@ -77,11 +77,7 @@ class ImageClassifierOptions:
|
||||||
"""
|
"""
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
display_names_locale: Optional[str] = None
|
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||||
max_results: Optional[int] = None
|
|
||||||
score_threshold: Optional[float] = None
|
|
||||||
category_allowlist: Optional[List[str]] = None
|
|
||||||
category_denylist: Optional[List[str]] = None
|
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[classifications_module.ClassificationResult],
|
Callable[[classifications_module.ClassificationResult],
|
||||||
None]] = None
|
None]] = None
|
||||||
|
@ -91,14 +87,7 @@ class ImageClassifierOptions:
|
||||||
"""Generates an ImageClassifierOptions protobuf object."""
|
"""Generates an ImageClassifierOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||||
|
classifier_options_proto = self.classifier_options.to_pb2()
|
||||||
classifier_options_proto = _ClassifierOptionsProto(
|
|
||||||
display_names_locale=self.display_names_locale,
|
|
||||||
max_results=self.max_results,
|
|
||||||
score_threshold=self.score_threshold,
|
|
||||||
category_allowlist=self.category_allowlist,
|
|
||||||
category_denylist=self.category_denylist
|
|
||||||
)
|
|
||||||
|
|
||||||
return _ImageClassifierOptionsProto(
|
return _ImageClassifierOptionsProto(
|
||||||
base_options=base_options_proto,
|
base_options=base_options_proto,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user