Moved ClassifierOptions to mediapipe/tasks/python/components to align with the C++ API

This commit is contained in:
kinaryml 2022-09-21 04:22:33 -07:00
parent 7287e5a0ed
commit d8f7c5a43b
5 changed files with 146 additions and 26 deletions

View File

@ -0,0 +1,28 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "classifier_options",
srcs = ["classifier_options.py"],
deps = [
"//mediapipe/tasks/cc/components:classifier_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -0,0 +1,92 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifier options data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.tasks.cc.components import classifier_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
@dataclasses.dataclass
class ClassifierOptions:
"""Options for classification processor.
Attributes:
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty, detection
results whose category name is not in this set will be filtered out.
Duplicate or unknown category names are ignored. Mutually exclusive with
`category_denylist`.
category_denylist: Denylist of category names. If non-empty, detection
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
"""
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassifierOptionsProto:
"""Generates a ClassifierOptions protobuf object."""
return _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _ClassifierOptionsProto
) -> 'ClassifierOptions':
"""Creates a `ClassifierOptions` object from the given protobuf object."""
return ClassifierOptions(
score_threshold=pb2_obj.score_threshold,
category_allowlist=[
str(name) for name in pb2_obj.class_name_allowlist
],
category_denylist=[
str(name) for name in pb2_obj.class_name_denylist
],
display_names_locale=pb2_obj.display_names_locale,
max_results=pb2_obj.max_results)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassifierOptions):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -19,6 +19,7 @@ from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components import classifier_options
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.core import base_options as base_options_module
@ -27,6 +28,7 @@ from mediapipe.tasks.python.vision import image_classification
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category_module.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
@ -136,8 +138,9 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, classifier_options=classifier_options)
classifier = _ImageClassifier.create_from_options(options)
# Performs image classification on the input.
@ -163,8 +166,9 @@ class ImageClassifierTest(parameterized.TestCase):
# Should never happen
raise ValueError('model_file_type is invalid.')
classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs object detection on the input.
image_result = classifier.classify(self.test_image)
@ -172,9 +176,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertEqual(image_result, expected_classification_result)
def test_score_threshold_option(self):
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path),
score_threshold=_SCORE_THRESHOLD)
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -189,9 +194,10 @@ class ImageClassifierTest(parameterized.TestCase):
f'{classification}')
def test_max_results_option(self):
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path),
max_results=_MAX_RESULTS)
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -201,9 +207,10 @@ class ImageClassifierTest(parameterized.TestCase):
len(categories), _MAX_RESULTS, 'Too many results returned.')
def test_allow_list_option(self):
classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path),
category_allowlist=_ALLOW_LIST)
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -216,9 +223,10 @@ class ImageClassifierTest(parameterized.TestCase):
f'Label {label} found but not in label allow list')
def test_deny_list_option(self):
classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path),
category_denylist=_DENY_LIST)
base_options=_BaseOptions(file_name=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -236,16 +244,19 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
classifier_options = _ClassifierOptions(category_allowlist=['foo'],
category_denylist=['bar'])
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path),
category_allowlist=['foo'],
category_denylist=['bar'])
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_empty_classification_outputs(self):
classifier_options = _ClassifierOptions(score_threshold=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(file_name=self.model_path), score_threshold=1)
base_options=_BaseOptions(file_name=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)

View File

@ -46,8 +46,8 @@ py_library(
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components:classifier_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2",
"//mediapipe/tasks/python/components:classifier_options",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -21,8 +21,8 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.components import classifier_options_pb2
from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2
from mediapipe.tasks.python.components import classifier_options
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
@ -31,8 +31,8 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner_module.TaskRunner
@ -77,11 +77,7 @@ class ImageClassifierOptions:
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[
Callable[[classifications_module.ClassificationResult],
None]] = None
@ -91,14 +87,7 @@ class ImageClassifierOptions:
"""Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
classifier_options_proto = _ClassifierOptionsProto(
display_names_locale=self.display_names_locale,
max_results=self.max_results,
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist
)
classifier_options_proto = self.classifier_options.to_pb2()
return _ImageClassifierOptionsProto(
base_options=base_options_proto,