Merge pull request #3832 from kinaryml:update-classification-tasks-output-stream
PiperOrigin-RevId: 487638077
This commit is contained in:
commit
bbf4ff0300
|
@ -85,21 +85,12 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classifications",
|
||||
srcs = ["classifications.py"],
|
||||
deps = [
|
||||
":category",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classification_result",
|
||||
srcs = ["classification_result.py"],
|
||||
deps = [
|
||||
":category",
|
||||
"//mediapipe/framework/formats:classification_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
|
|
|
@ -16,10 +16,13 @@
|
|||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_ClassificationProto = classification_pb2.Classification
|
||||
_ClassificationListProto = classification_pb2.ClassificationList
|
||||
_ClassificationsProto = classifications_pb2.Classifications
|
||||
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
||||
|
||||
|
@ -41,6 +44,22 @@ class Classifications:
|
|||
head_index: int
|
||||
head_name: Optional[str] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationsProto:
|
||||
"""Generates a Classifications protobuf object."""
|
||||
classification_list_proto = _ClassificationListProto()
|
||||
for category in self.categories:
|
||||
classification_proto = _ClassificationProto(
|
||||
index=category.index,
|
||||
score=category.score,
|
||||
label=category.category_name,
|
||||
display_name=category.display_name)
|
||||
classification_list_proto.classification.append(classification_proto)
|
||||
return _ClassificationsProto(
|
||||
classification_list=classification_list_proto,
|
||||
head_index=self.head_index,
|
||||
head_name=self.head_name)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
||||
|
@ -78,6 +97,15 @@ class ClassificationResult:
|
|||
classifications: List[Classifications]
|
||||
timestamp_ms: Optional[int] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationResultProto:
|
||||
"""Generates a ClassificationResult protobuf object."""
|
||||
return _ClassificationResultProto(
|
||||
classifications=[
|
||||
classification.to_pb2() for classification in self.classifications
|
||||
],
|
||||
timestamp_ms=self.timestamp_ms)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
|
|
|
@ -1,168 +0,0 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Classifications data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
|
||||
_ClassificationsProto = classifications_pb2.Classifications
|
||||
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClassificationEntry:
|
||||
"""List of predicted classes (aka labels) for a given classifier head.
|
||||
|
||||
Attributes:
|
||||
categories: The array of predicted categories, usually sorted by descending
|
||||
scores (e.g. from high to low probability).
|
||||
timestamp_ms: The optional timestamp (in milliseconds) associated to the
|
||||
classification entry. This is useful for time series use cases, e.g.,
|
||||
audio classification.
|
||||
"""
|
||||
|
||||
categories: List[category_module.Category]
|
||||
timestamp_ms: Optional[int] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationEntryProto:
|
||||
"""Generates a ClassificationEntry protobuf object."""
|
||||
return _ClassificationEntryProto(
|
||||
categories=[category.to_pb2() for category in self.categories],
|
||||
timestamp_ms=self.timestamp_ms)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
|
||||
"""Creates a `ClassificationEntry` object from the given protobuf object."""
|
||||
return ClassificationEntry(
|
||||
categories=[
|
||||
category_module.Category.create_from_pb2(category)
|
||||
for category in pb2_obj.categories
|
||||
],
|
||||
timestamp_ms=pb2_obj.timestamp_ms)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, ClassificationEntry):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Classifications:
|
||||
"""Represents the classifications for a given classifier head.
|
||||
|
||||
Attributes:
|
||||
entries: A list of `ClassificationEntry` objects.
|
||||
head_index: The index of the classifier head these categories refer to. This
|
||||
is useful for multi-head models.
|
||||
head_name: The name of the classifier head, which is the corresponding
|
||||
tensor metadata name.
|
||||
"""
|
||||
|
||||
entries: List[ClassificationEntry]
|
||||
head_index: int
|
||||
head_name: str
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationsProto:
|
||||
"""Generates a Classifications protobuf object."""
|
||||
return _ClassificationsProto(
|
||||
entries=[entry.to_pb2() for entry in self.entries],
|
||||
head_index=self.head_index,
|
||||
head_name=self.head_name)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
||||
"""Creates a `Classifications` object from the given protobuf object."""
|
||||
return Classifications(
|
||||
entries=[
|
||||
ClassificationEntry.create_from_pb2(entry)
|
||||
for entry in pb2_obj.entries
|
||||
],
|
||||
head_index=pb2_obj.head_index,
|
||||
head_name=pb2_obj.head_name)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, Classifications):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClassificationResult:
|
||||
"""Contains one set of results per classifier head.
|
||||
|
||||
Attributes:
|
||||
classifications: A list of `Classifications` objects.
|
||||
"""
|
||||
|
||||
classifications: List[Classifications]
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationResultProto:
|
||||
"""Generates a ClassificationResult protobuf object."""
|
||||
return _ClassificationResultProto(classifications=[
|
||||
classification.to_pb2() for classification in self.classifications
|
||||
])
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
|
||||
"""Creates a `ClassificationResult` object from the given protobuf object.
|
||||
"""
|
||||
return ClassificationResult(classifications=[
|
||||
Classifications.create_from_pb2(classification)
|
||||
for classification in pb2_obj.classifications
|
||||
])
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, ClassificationResult):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -27,7 +27,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:classifications",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
|
|
|
@ -20,18 +20,17 @@ from absl.testing import absltest
|
|||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.tasks.python.components.containers import category
|
||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.text import text_classifier
|
||||
|
||||
TextClassifierResult = classification_result_module.ClassificationResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_Category = category.Category
|
||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||
_Classifications = classifications_module.Classifications
|
||||
_TextClassifierResult = classifications_module.ClassificationResult
|
||||
_Classifications = classification_result_module.Classifications
|
||||
_TextClassifier = text_classifier.TextClassifier
|
||||
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
||||
|
||||
|
@ -43,10 +42,9 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
|
|||
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
||||
'Strongly recommend it!')
|
||||
|
||||
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||
_BERT_NEGATIVE_RESULTS = TextClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=0,
|
||||
|
@ -59,15 +57,13 @@ _BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
|||
display_name='',
|
||||
category_name='positive')
|
||||
],
|
||||
timestamp_ms=0)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||
],
|
||||
timestamp_ms=0)
|
||||
_BERT_POSITIVE_RESULTS = TextClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=1,
|
||||
|
@ -80,15 +76,13 @@ _BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
|||
display_name='',
|
||||
category_name='negative')
|
||||
],
|
||||
timestamp_ms=0)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||
],
|
||||
timestamp_ms=0)
|
||||
_REGEX_NEGATIVE_RESULTS = TextClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=0,
|
||||
|
@ -101,15 +95,13 @@ _REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
|||
display_name='',
|
||||
category_name='Positive')
|
||||
],
|
||||
timestamp_ms=0)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||
],
|
||||
timestamp_ms=0)
|
||||
_REGEX_POSITIVE_RESULTS = TextClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=1,
|
||||
|
@ -122,11 +114,10 @@ _REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
|||
display_name='',
|
||||
category_name='Negative')
|
||||
],
|
||||
timestamp_ms=0)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
],
|
||||
timestamp_ms=0)
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
|
|
|
@ -47,7 +47,7 @@ py_test(
|
|||
deps = [
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:classifications",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
|
|
|
@ -23,8 +23,8 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image
|
||||
from mediapipe.tasks.python.components.containers import category
|
||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -33,13 +33,12 @@ from mediapipe.tasks.python.vision import image_classifier
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_Category = category.Category
|
||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||
_Classifications = classifications_module.Classifications
|
||||
_ClassificationResult = classifications_module.ClassificationResult
|
||||
_Category = category_module.Category
|
||||
_Classifications = classification_result_module.Classifications
|
||||
_Image = image.Image
|
||||
_ImageClassifier = image_classifier.ImageClassifier
|
||||
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
||||
|
@ -55,22 +54,19 @@ _MAX_RESULTS = 3
|
|||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(classifications=[
|
||||
def _generate_empty_results() -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
|
||||
categories=[], head_index=0, head_name='probability')
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
timestamp_ms=0)
|
||||
|
||||
|
||||
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(classifications=[
|
||||
def _generate_burger_results() -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=934,
|
||||
|
@ -93,18 +89,16 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
display_name='',
|
||||
category_name='meat loaf')
|
||||
],
|
||||
timestamp_ms=timestamp_ms)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
],
|
||||
timestamp_ms=0)
|
||||
|
||||
|
||||
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
return _ClassificationResult(classifications=[
|
||||
def _generate_soccer_ball_results() -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
entries=[
|
||||
_ClassificationEntry(
|
||||
categories=[
|
||||
_Category(
|
||||
index=806,
|
||||
|
@ -112,11 +106,10 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
|||
display_name='',
|
||||
category_name='soccer ball')
|
||||
],
|
||||
timestamp_ms=timestamp_ms)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
])
|
||||
],
|
||||
timestamp_ms=0)
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
|
@ -163,8 +156,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
self.assertIsInstance(classifier, _ImageClassifier)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
|
||||
def test_classify(self, model_file_type, max_results,
|
||||
expected_classification_result):
|
||||
# Creates classifier.
|
||||
|
@ -193,8 +186,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifier.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
|
||||
def test_classify_in_context(self, model_file_type, max_results,
|
||||
expected_classification_result):
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
|
@ -234,7 +227,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||
_generate_soccer_ball_results(0).to_pb2())
|
||||
_generate_soccer_ball_results().to_pb2())
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
custom_classifier_options = _ClassifierOptions(
|
||||
|
@ -248,8 +241,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for entry in classification.entries:
|
||||
score = entry.categories[0].score
|
||||
for category in classification.categories:
|
||||
score = category.score
|
||||
self.assertGreaterEqual(
|
||||
score, _SCORE_THRESHOLD,
|
||||
f'Classification with score lower than threshold found. '
|
||||
|
@ -264,7 +257,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
categories = image_result.classifications[0].entries[0].categories
|
||||
categories = image_result.classifications[0].categories
|
||||
|
||||
self.assertLessEqual(
|
||||
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
||||
|
@ -281,8 +274,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for entry in classification.entries:
|
||||
label = entry.categories[0].category_name
|
||||
for category in classification.categories:
|
||||
label = category.category_name
|
||||
self.assertIn(label, _ALLOW_LIST,
|
||||
f'Label {label} found but not in label allow list')
|
||||
|
||||
|
@ -297,8 +290,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for entry in classification.entries:
|
||||
label = entry.categories[0].category_name
|
||||
for category in classification.categories:
|
||||
label = category.category_name
|
||||
self.assertNotIn(label, _DENY_LIST,
|
||||
f'Label {label} found but in deny list.')
|
||||
|
||||
|
@ -324,7 +317,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
self.assertEmpty(image_result.classifications[0].entries[0].categories)
|
||||
self.assertEmpty(image_result.classifications[0].categories)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -402,9 +395,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
self.test_image, timestamp)
|
||||
test_utils.assert_proto_equals(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2())
|
||||
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_burger_results().to_pb2())
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||
|
@ -423,9 +415,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, image_processing_options)
|
||||
test_utils.assert_proto_equals(
|
||||
self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
||||
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results().to_pb2())
|
||||
|
||||
def test_calling_classify_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -460,15 +451,15 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters((0, _generate_burger_results),
|
||||
(1, _generate_empty_results))
|
||||
def test_classify_async_calls(self, threshold, expected_result_fn):
|
||||
@parameterized.parameters((0, _generate_burger_results()),
|
||||
(1, _generate_empty_results()))
|
||||
def test_classify_async_calls(self, threshold, expected_result):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
def check_result(result: ImageClassifierResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||
expected_result_fn(timestamp_ms).to_pb2())
|
||||
expected_result.to_pb2())
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
|
@ -496,11 +487,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||
def check_result(result: ImageClassifierResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
test_utils.assert_proto_equals(
|
||||
self, result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||
_generate_soccer_ball_results().to_pb2())
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
|
|
|
@ -28,7 +28,7 @@ py_library(
|
|||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:classifications",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -19,21 +19,21 @@ from mediapipe.python import packet_creator
|
|||
from mediapipe.python import packet_getter
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classifications
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||
|
||||
TextClassifierResult = classifications.ClassificationResult
|
||||
TextClassifierResult = classification_result_module.ClassificationResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
|
||||
_TEXT_IN_STREAM_NAME = 'text_in'
|
||||
_TEXT_TAG = 'TEXT'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
|
||||
|
@ -104,10 +104,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
|||
task_graph=_TASK_GRAPH_NAME,
|
||||
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
||||
output_streams=[
|
||||
':'.join([
|
||||
_CLASSIFICATION_RESULT_TAG,
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
||||
])
|
||||
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
|
||||
],
|
||||
task_options=options)
|
||||
return cls(task_info.generate_graph_config())
|
||||
|
@ -131,10 +128,6 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
|||
|
||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||
classification_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(
|
||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||
|
||||
return TextClassifierResult([
|
||||
classifications.Classifications.create_from_pb2(classification)
|
||||
for classification in classification_result_proto.classifications
|
||||
])
|
||||
return TextClassifierResult.create_from_pb2(classification_result_proto)
|
||||
|
|
|
@ -48,7 +48,7 @@ py_library(
|
|||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:classifications",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
|
|
|
@ -18,12 +18,11 @@ from typing import Callable, Mapping, Optional
|
|||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
# TODO: Import MPImage directly one we have an alias
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classifications
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -33,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||
_NormalizedRect = rect.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||
|
@ -41,8 +41,8 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
|||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
|
@ -71,9 +71,8 @@ class ImageClassifierOptions:
|
|||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||
result_callback: Optional[
|
||||
Callable[[classifications.ClassificationResult, image_module.Image, int],
|
||||
None]] = None
|
||||
result_callback: Optional[Callable[
|
||||
[ImageClassifierResult, image_module.Image, int], None]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
||||
|
@ -137,17 +136,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||
classification_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(
|
||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||
|
||||
classification_result = classifications.ClassificationResult([
|
||||
classifications.Classifications.create_from_pb2(classification)
|
||||
for classification in classification_result_proto.classifications
|
||||
])
|
||||
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
options.result_callback(classification_result, image,
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
options.result_callback(
|
||||
ImageClassifierResult.create_from_pb2(classification_result_proto),
|
||||
image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
task_graph=_TASK_GRAPH_NAME,
|
||||
|
@ -156,10 +150,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([
|
||||
_CLASSIFICATION_RESULT_TAG,
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
||||
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||
],
|
||||
task_options=options)
|
||||
return cls(
|
||||
|
@ -172,7 +164,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
self,
|
||||
image: image_module.Image,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> classifications.ClassificationResult:
|
||||
) -> ImageClassifierResult:
|
||||
"""Performs image classification on the provided MediaPipe Image.
|
||||
|
||||
Args:
|
||||
|
@ -196,20 +188,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||
classification_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(
|
||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||
|
||||
return classifications.ClassificationResult([
|
||||
classifications.Classifications.create_from_pb2(classification)
|
||||
for classification in classification_result_proto.classifications
|
||||
])
|
||||
return ImageClassifierResult.create_from_pb2(classification_result_proto)
|
||||
|
||||
def classify_for_video(
|
||||
self,
|
||||
image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> classifications.ClassificationResult:
|
||||
) -> ImageClassifierResult:
|
||||
"""Performs image classification on the provided video frames.
|
||||
|
||||
Only use this method when the ImageClassifier is created with the video
|
||||
|
@ -241,13 +229,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||
classification_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(
|
||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||
|
||||
return classifications.ClassificationResult([
|
||||
classifications.Classifications.create_from_pb2(classification)
|
||||
for classification in classification_result_proto.classifications
|
||||
])
|
||||
return ImageClassifierResult.create_from_pb2(classification_result_proto)
|
||||
|
||||
def classify_async(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user