Updated classifier python tasks to use the new classification_result dataclass

This commit is contained in:
kinaryml 2022-11-08 10:28:24 -08:00
parent 26066787b3
commit 0fd78c2ec6
11 changed files with 191 additions and 380 deletions

View File

@ -85,21 +85,12 @@ py_library(
],
)
py_library(
name = "classifications",
srcs = ["classifications.py"],
deps = [
":category",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "classification_result",
srcs = ["classification_result.py"],
deps = [
":category",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],

View File

@ -16,10 +16,13 @@
import dataclasses
from typing import List, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@ -41,6 +44,24 @@ class Classifications:
head_index: int
head_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationsProto:
"""Generates a Classifications protobuf object."""
classification_list_proto = _ClassificationListProto()
classification_list_proto.classification.extend([
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name
)
for category in self.categories
])
return _ClassificationsProto(
classification_list=classification_list_proto,
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
@ -78,6 +99,16 @@ class ClassificationResult:
classifications: List[Classifications]
timestamp_ms: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationResultProto:
"""Generates a ClassificationResult protobuf object."""
return _ClassificationResultProto(
classifications=[
classification.to_pb2()
for classification in self.classifications
],
timestamp_ms=self.timestamp_ms)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(

View File

@ -1,168 +0,0 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@dataclasses.dataclass
class ClassificationEntry:
"""List of predicted classes (aka labels) for a given classifier head.
Attributes:
categories: The array of predicted categories, usually sorted by descending
scores (e.g. from high to low probability).
timestamp_ms: The optional timestamp (in milliseconds) associated to the
classification entry. This is useful for time series use cases, e.g.,
audio classification.
"""
categories: List[category_module.Category]
timestamp_ms: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationEntryProto:
"""Generates a ClassificationEntry protobuf object."""
return _ClassificationEntryProto(
categories=[category.to_pb2() for category in self.categories],
timestamp_ms=self.timestamp_ms)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
"""Creates a `ClassificationEntry` object from the given protobuf object."""
return ClassificationEntry(
categories=[
category_module.Category.create_from_pb2(category)
for category in pb2_obj.categories
],
timestamp_ms=pb2_obj.timestamp_ms)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationEntry):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class Classifications:
"""Represents the classifications for a given classifier head.
Attributes:
entries: A list of `ClassificationEntry` objects.
head_index: The index of the classifier head these categories refer to. This
is useful for multi-head models.
head_name: The name of the classifier head, which is the corresponding
tensor metadata name.
"""
entries: List[ClassificationEntry]
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationsProto:
"""Generates a Classifications protobuf object."""
return _ClassificationsProto(
entries=[entry.to_pb2() for entry in self.entries],
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
"""Creates a `Classifications` object from the given protobuf object."""
return Classifications(
entries=[
ClassificationEntry.create_from_pb2(entry)
for entry in pb2_obj.entries
],
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Classifications):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class ClassificationResult:
"""Contains one set of results per classifier head.
Attributes:
classifications: A list of `Classifications` objects.
"""
classifications: List[Classifications]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationResultProto:
"""Generates a ClassificationResult protobuf object."""
return _ClassificationResultProto(classifications=[
classification.to_pb2() for classification in self.classifications
])
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
"""Creates a `ClassificationResult` object from the given protobuf object.
"""
return ClassificationResult(classifications=[
Classifications.create_from_pb2(classification)
for classification in pb2_obj.classifications
])
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -27,7 +27,7 @@ py_test(
],
deps = [
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",

View File

@ -20,18 +20,17 @@ from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier
TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_TextClassifierResult = classifications_module.ClassificationResult
_Classifications = classification_result_module.Classifications
_TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions
@ -43,90 +42,75 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
_POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.'
'Strongly recommend it!')
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0,
score=0.999479,
display_name='',
category_name='negative'),
_Category(
index=1,
score=0.00052154,
display_name='',
category_name='positive')
],
timestamp_ms=0)
_BERT_NEGATIVE_RESULTS = TextClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=0,
score=0.999479,
display_name='',
category_name='negative'),
_Category(
index=1,
score=0.00052154,
display_name='',
category_name='positive')
],
head_index=0,
head_name='probability')
])
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1,
score=0.999466,
display_name='',
category_name='positive'),
_Category(
index=0,
score=0.000533596,
display_name='',
category_name='negative')
],
timestamp_ms=0)
],
timestamp_ms=0)
_BERT_POSITIVE_RESULTS = TextClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=1,
score=0.999466,
display_name='',
category_name='positive'),
_Category(
index=0,
score=0.000533596,
display_name='',
category_name='negative')
],
head_index=0,
head_name='probability')
])
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0,
score=0.81313,
display_name='',
category_name='Negative'),
_Category(
index=1,
score=0.1868704,
display_name='',
category_name='Positive')
],
timestamp_ms=0)
],
timestamp_ms=0)
_REGEX_NEGATIVE_RESULTS = TextClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=0,
score=0.81313,
display_name='',
category_name='Negative'),
_Category(
index=1,
score=0.1868704,
display_name='',
category_name='Positive')
],
head_index=0,
head_name='probability')
])
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1,
score=0.5134273,
display_name='',
category_name='Positive'),
_Category(
index=0,
score=0.486573,
display_name='',
category_name='Negative')
],
timestamp_ms=0)
],
timestamp_ms=0)
_REGEX_POSITIVE_RESULTS = TextClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=1,
score=0.5134273,
display_name='',
category_name='Positive'),
_Category(
index=0,
score=0.486573,
display_name='',
category_name='Negative')
],
head_index=0,
head_name='probability')
])
],
timestamp_ms=0)
class ModelFileType(enum.Enum):

View File

@ -47,7 +47,7 @@ py_test(
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",

View File

@ -23,8 +23,8 @@ from absl.testing import parameterized
import numpy as np
from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
@ -33,13 +33,12 @@ from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_ClassificationResult = classifications_module.ClassificationResult
_Category = category_module.Category
_Classifications = classification_result_module.Classifications
_Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
@ -55,68 +54,56 @@ _MAX_RESULTS = 3
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
def _generate_empty_results() -> ImageClassifierResult:
return ImageClassifierResult(classifications=[
_Classifications(categories=[], head_index=0, head_name='probability')
],
timestamp_ms=0)
def _generate_burger_results() -> ImageClassifierResult:
return ImageClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=934,
score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.00632786,
display_name='',
category_name='meat loaf')
],
head_index=0,
head_name='probability')
])
],
timestamp_ms=0)
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=934,
score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.00632786,
display_name='',
category_name='meat loaf')
],
timestamp_ms=timestamp_ms)
def _generate_soccer_ball_results() -> ImageClassifierResult:
return ImageClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=806,
score=0.996527,
display_name='',
category_name='soccer ball')
],
head_index=0,
head_name='probability')
])
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=806,
score=0.996527,
display_name='',
category_name='soccer ball')
],
timestamp_ms=timestamp_ms)
],
head_index=0,
head_name='probability')
])
],
timestamp_ms=0)
class ModelFileType(enum.Enum):
@ -165,8 +152,8 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _ImageClassifier)
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify(self, model_file_type, max_results,
expected_classification_result):
# Creates classifier.
@ -195,8 +182,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result):
if model_file_type is ModelFileType.FILE_NAME:
@ -236,7 +223,7 @@ class ImageClassifierTest(parameterized.TestCase):
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2())
_generate_soccer_ball_results().to_pb2())
def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
@ -250,8 +237,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
score = entry.categories[0].score
for category in classification.categories:
score = category.score
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
f'Classification with score lower than threshold found. '
@ -266,7 +253,7 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
categories = image_result.classifications[0].entries[0].categories
categories = image_result.classifications[0].categories
self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.')
@ -283,8 +270,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
for category in classification.categories:
label = category.category_name
self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list')
@ -299,8 +286,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
for category in classification.categories:
label = category.category_name
self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.')
@ -326,7 +313,7 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
self.assertEmpty(image_result.classifications[0].entries[0].categories)
self.assertEmpty(image_result.classifications[0].categories)
def test_missing_result_callback(self):
options = _ImageClassifierOptions(
@ -406,7 +393,7 @@ class ImageClassifierTest(parameterized.TestCase):
self.test_image, timestamp)
test_utils.assert_proto_equals(
self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2())
_generate_burger_results().to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
@ -427,7 +414,7 @@ class ImageClassifierTest(parameterized.TestCase):
test_image, timestamp, image_processing_options)
test_utils.assert_proto_equals(
self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2())
_generate_soccer_ball_results().to_pb2())
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
@ -462,15 +449,15 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0)
@parameterized.parameters((0, _generate_burger_results),
(1, _generate_empty_results))
def test_classify_async_calls(self, threshold, expected_result_fn):
@parameterized.parameters((0, _generate_burger_results()),
(1, _generate_empty_results()))
def test_classify_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int):
test_utils.assert_proto_equals(self, result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2())
expected_result.to_pb2())
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
@ -498,11 +485,11 @@ class ImageClassifierTest(parameterized.TestCase):
image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int):
test_utils.assert_proto_equals(
self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2())
_generate_soccer_ball_results().to_pb2())
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms)

View File

@ -28,7 +28,7 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -19,21 +19,21 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
TextClassifierResult = classifications.ClassificationResult
TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
@ -105,8 +105,8 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
output_streams=[
':'.join([
_CLASSIFICATION_RESULT_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
_CLASSIFICATIONS_TAG,
_CLASSIFICATIONS_STREAM_NAME
])
],
task_options=options)
@ -132,9 +132,6 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
output_packets[_CLASSIFICATIONS_STREAM_NAME]))
return TextClassifierResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
return TextClassifierResult.create_from_pb2(classification_result_proto)

View File

@ -48,7 +48,7 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",

View File

@ -23,7 +23,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
@ -33,6 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
@ -41,8 +42,8 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
@ -72,7 +73,7 @@ class ImageClassifierOptions:
running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[
Callable[[classifications.ClassificationResult, image_module.Image, int],
Callable[[ImageClassifierResult, image_module.Image, int],
None]] = None
@doc_controls.do_not_generate_docs
@ -137,17 +138,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
classification_result = classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
ImageClassifierResult.create_from_pb2(classification_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -157,8 +154,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([
_CLASSIFICATION_RESULT_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
_CLASSIFICATIONS_TAG,
_CLASSIFICATIONS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
@ -172,7 +169,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult:
) -> ImageClassifierResult:
"""Performs image classification on the provided MediaPipe Image.
Args:
@ -196,20 +193,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
return classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
return ImageClassifierResult.create_from_pb2(classification_result_proto)
def classify_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult:
) -> ImageClassifierResult:
"""Performs image classification on the provided video frames.
Only use this method when the ImageClassifier is created with the video
@ -241,13 +234,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
return classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
return ImageClassifierResult.create_from_pb2(classification_result_proto)
def classify_async(
self,