Updated classifier python tasks to use the new classification_result dataclass

This commit is contained in:
kinaryml 2022-11-08 10:28:24 -08:00
parent 26066787b3
commit 0fd78c2ec6
11 changed files with 191 additions and 380 deletions

View File

@ -85,21 +85,12 @@ py_library(
], ],
) )
py_library(
name = "classifications",
srcs = ["classifications.py"],
deps = [
":category",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library( py_library(
name = "classification_result", name = "classification_result",
srcs = ["classification_result.py"], srcs = ["classification_result.py"],
deps = [ deps = [
":category", ":category",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],

View File

@ -16,10 +16,13 @@
import dataclasses import dataclasses
from typing import List, Optional from typing import List, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_ClassificationsProto = classifications_pb2.Classifications _ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult _ClassificationResultProto = classifications_pb2.ClassificationResult
@ -41,6 +44,24 @@ class Classifications:
head_index: int head_index: int
head_name: Optional[str] = None head_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationsProto:
"""Generates a Classifications protobuf object."""
classification_list_proto = _ClassificationListProto()
classification_list_proto.classification.extend([
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name
)
for category in self.categories
])
return _ClassificationsProto(
classification_list=classification_list_proto,
head_index=self.head_index,
head_name=self.head_name)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
@ -78,6 +99,16 @@ class ClassificationResult:
classifications: List[Classifications] classifications: List[Classifications]
timestamp_ms: Optional[int] = None timestamp_ms: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationResultProto:
"""Generates a ClassificationResult protobuf object."""
return _ClassificationResultProto(
classifications=[
classification.to_pb2()
for classification in self.classifications
],
timestamp_ms=self.timestamp_ms)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def create_from_pb2( def create_from_pb2(

View File

@ -1,168 +0,0 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@dataclasses.dataclass
class ClassificationEntry:
"""List of predicted classes (aka labels) for a given classifier head.
Attributes:
categories: The array of predicted categories, usually sorted by descending
scores (e.g. from high to low probability).
timestamp_ms: The optional timestamp (in milliseconds) associated to the
classification entry. This is useful for time series use cases, e.g.,
audio classification.
"""
categories: List[category_module.Category]
timestamp_ms: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationEntryProto:
"""Generates a ClassificationEntry protobuf object."""
return _ClassificationEntryProto(
categories=[category.to_pb2() for category in self.categories],
timestamp_ms=self.timestamp_ms)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
"""Creates a `ClassificationEntry` object from the given protobuf object."""
return ClassificationEntry(
categories=[
category_module.Category.create_from_pb2(category)
for category in pb2_obj.categories
],
timestamp_ms=pb2_obj.timestamp_ms)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationEntry):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class Classifications:
"""Represents the classifications for a given classifier head.
Attributes:
entries: A list of `ClassificationEntry` objects.
head_index: The index of the classifier head these categories refer to. This
is useful for multi-head models.
head_name: The name of the classifier head, which is the corresponding
tensor metadata name.
"""
entries: List[ClassificationEntry]
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationsProto:
"""Generates a Classifications protobuf object."""
return _ClassificationsProto(
entries=[entry.to_pb2() for entry in self.entries],
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
"""Creates a `Classifications` object from the given protobuf object."""
return Classifications(
entries=[
ClassificationEntry.create_from_pb2(entry)
for entry in pb2_obj.entries
],
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Classifications):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class ClassificationResult:
"""Contains one set of results per classifier head.
Attributes:
classifications: A list of `Classifications` objects.
"""
classifications: List[Classifications]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationResultProto:
"""Generates a ClassificationResult protobuf object."""
return _ClassificationResultProto(classifications=[
classification.to_pb2() for classification in self.classifications
])
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
"""Creates a `ClassificationResult` object from the given protobuf object.
"""
return ClassificationResult(classifications=[
Classifications.create_from_pb2(classification)
for classification in pb2_obj.classifications
])
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -27,7 +27,7 @@ py_test(
], ],
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",

View File

@ -20,18 +20,17 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier from mediapipe.tasks.python.text import text_classifier
TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classification_result_module.Classifications
_Classifications = classifications_module.Classifications
_TextClassifierResult = classifications_module.ClassificationResult
_TextClassifier = text_classifier.TextClassifier _TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions _TextClassifierOptions = text_classifier.TextClassifierOptions
@ -43,90 +42,75 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
_POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.' _POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.'
'Strongly recommend it!') 'Strongly recommend it!')
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications( _BERT_NEGATIVE_RESULTS = TextClassifierResult(classifications=[
entries=[ _Classifications(categories=[
_ClassificationEntry( _Category(
categories=[ index=0,
_Category( score=0.999479,
index=0, display_name='',
score=0.999479, category_name='negative'),
display_name='', _Category(
category_name='negative'), index=1,
_Category( score=0.00052154,
index=1, display_name='',
score=0.00052154, category_name='positive')
display_name='',
category_name='positive')
],
timestamp_ms=0)
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ timestamp_ms=0)
_Classifications( _BERT_POSITIVE_RESULTS = TextClassifierResult(classifications=[
entries=[ _Classifications(categories=[
_ClassificationEntry( _Category(
categories=[ index=1,
_Category( score=0.999466,
index=1, display_name='',
score=0.999466, category_name='positive'),
display_name='', _Category(
category_name='positive'), index=0,
_Category( score=0.000533596,
index=0, display_name='',
score=0.000533596, category_name='negative')
display_name='',
category_name='negative')
],
timestamp_ms=0)
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ timestamp_ms=0)
_Classifications( _REGEX_NEGATIVE_RESULTS = TextClassifierResult(classifications=[
entries=[ _Classifications(categories=[
_ClassificationEntry( _Category(
categories=[ index=0,
_Category( score=0.81313,
index=0, display_name='',
score=0.81313, category_name='Negative'),
display_name='', _Category(
category_name='Negative'), index=1,
_Category( score=0.1868704,
index=1, display_name='',
score=0.1868704, category_name='Positive')
display_name='',
category_name='Positive')
],
timestamp_ms=0)
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ timestamp_ms=0)
_Classifications( _REGEX_POSITIVE_RESULTS = TextClassifierResult(classifications=[
entries=[ _Classifications(categories=[
_ClassificationEntry( _Category(
categories=[ index=1,
_Category( score=0.5134273,
index=1, display_name='',
score=0.5134273, category_name='Positive'),
display_name='', _Category(
category_name='Positive'), index=0,
_Category( score=0.486573,
index=0, display_name='',
score=0.486573, category_name='Negative')
display_name='',
category_name='Negative')
],
timestamp_ms=0)
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
timestamp_ms=0)
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):

View File

@ -47,7 +47,7 @@ py_test(
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",

View File

@ -23,8 +23,8 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
@ -33,13 +33,12 @@ from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_Rect = rect.Rect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category_module.Category
_ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classification_result_module.Classifications
_Classifications = classifications_module.Classifications
_ClassificationResult = classifications_module.ClassificationResult
_Image = image.Image _Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier _ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions _ImageClassifierOptions = image_classifier.ImageClassifierOptions
@ -55,68 +54,56 @@ _MAX_RESULTS = 3
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: def _generate_empty_results() -> ImageClassifierResult:
return _ClassificationResult(classifications=[ return ImageClassifierResult(classifications=[
_Classifications( _Classifications(categories=[], head_index=0, head_name='probability')
entries=[ ],
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms) timestamp_ms=0)
def _generate_burger_results() -> ImageClassifierResult:
return ImageClassifierResult(classifications=[
_Classifications(categories=[
_Category(
index=934,
score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.00632786,
display_name='',
category_name='meat loaf')
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
timestamp_ms=0)
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: def _generate_soccer_ball_results() -> ImageClassifierResult:
return _ClassificationResult(classifications=[ return ImageClassifierResult(classifications=[
_Classifications( _Classifications(categories=[
entries=[ _Category(
_ClassificationEntry( index=806,
categories=[ score=0.996527,
_Category( display_name='',
index=934, category_name='soccer ball')
score=0.793959,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.0273929,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.0193408,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.00632786,
display_name='',
category_name='meat loaf')
],
timestamp_ms=timestamp_ms)
], ],
head_index=0, head_index=0,
head_name='probability') head_name='probability')
]) ],
timestamp_ms=0)
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=806,
score=0.996527,
display_name='',
category_name='soccer ball')
],
timestamp_ms=timestamp_ms)
],
head_index=0,
head_name='probability')
])
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
@ -165,8 +152,8 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _ImageClassifier) self.assertIsInstance(classifier, _ImageClassifier)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify(self, model_file_type, max_results, def test_classify(self, model_file_type, max_results,
expected_classification_result): expected_classification_result):
# Creates classifier. # Creates classifier.
@ -195,8 +182,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.close() classifier.close()
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify_in_context(self, model_file_type, max_results, def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result): expected_classification_result):
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
@ -236,7 +223,7 @@ class ImageClassifierTest(parameterized.TestCase):
image_result = classifier.classify(test_image, image_processing_options) image_result = classifier.classify(test_image, image_processing_options)
# Comparing results. # Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(), test_utils.assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results().to_pb2())
def test_score_threshold_option(self): def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions( custom_classifier_options = _ClassifierOptions(
@ -250,8 +237,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications classifications = image_result.classifications
for classification in classifications: for classification in classifications:
for entry in classification.entries: for category in classification.categories:
score = entry.categories[0].score score = category.score
self.assertGreaterEqual( self.assertGreaterEqual(
score, _SCORE_THRESHOLD, score, _SCORE_THRESHOLD,
f'Classification with score lower than threshold found. ' f'Classification with score lower than threshold found. '
@ -266,7 +253,7 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
categories = image_result.classifications[0].entries[0].categories categories = image_result.classifications[0].categories
self.assertLessEqual( self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.') len(categories), _MAX_RESULTS, 'Too many results returned.')
@ -283,8 +270,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications classifications = image_result.classifications
for classification in classifications: for classification in classifications:
for entry in classification.entries: for category in classification.categories:
label = entry.categories[0].category_name label = category.category_name
self.assertIn(label, _ALLOW_LIST, self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list') f'Label {label} found but not in label allow list')
@ -299,8 +286,8 @@ class ImageClassifierTest(parameterized.TestCase):
classifications = image_result.classifications classifications = image_result.classifications
for classification in classifications: for classification in classifications:
for entry in classification.entries: for category in classification.categories:
label = entry.categories[0].category_name label = category.category_name
self.assertNotIn(label, _DENY_LIST, self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.') f'Label {label} found but in deny list.')
@ -326,7 +313,7 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(self.test_image) image_result = classifier.classify(self.test_image)
self.assertEmpty(image_result.classifications[0].entries[0].categories) self.assertEmpty(image_result.classifications[0].categories)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -406,7 +393,7 @@ class ImageClassifierTest(parameterized.TestCase):
self.test_image, timestamp) self.test_image, timestamp)
test_utils.assert_proto_equals( test_utils.assert_proto_equals(
self, classification_result.to_pb2(), self, classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2()) _generate_burger_results().to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self): def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1) custom_classifier_options = _ClassifierOptions(max_results=1)
@ -427,7 +414,7 @@ class ImageClassifierTest(parameterized.TestCase):
test_image, timestamp, image_processing_options) test_image, timestamp, image_processing_options)
test_utils.assert_proto_equals( test_utils.assert_proto_equals(
self, classification_result.to_pb2(), self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2()) _generate_soccer_ball_results().to_pb2())
def test_calling_classify_in_live_stream_mode(self): def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -462,15 +449,15 @@ class ImageClassifierTest(parameterized.TestCase):
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0) classifier.classify_async(self.test_image, 0)
@parameterized.parameters((0, _generate_burger_results), @parameterized.parameters((0, _generate_burger_results()),
(1, _generate_empty_results)) (1, _generate_empty_results()))
def test_classify_async_calls(self, threshold, expected_result_fn): def test_classify_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
test_utils.assert_proto_equals(self, result.to_pb2(), test_utils.assert_proto_equals(self, result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2()) expected_result.to_pb2())
self.assertTrue( self.assertTrue(
np.array_equal(output_image.numpy_view(), np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view())) self.test_image.numpy_view()))
@ -498,11 +485,11 @@ class ImageClassifierTest(parameterized.TestCase):
image_processing_options = _ImageProcessingOptions(roi) image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
test_utils.assert_proto_equals( test_utils.assert_proto_equals(
self, result.to_pb2(), self, result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2()) _generate_soccer_ball_results().to_pb2())
self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height) self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)

View File

@ -28,7 +28,7 @@ py_library(
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -19,21 +19,21 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api from mediapipe.tasks.python.text.core import base_text_task_api
TextClassifierResult = classifications.ClassificationResult TextClassifierResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in' _TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT' _TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
@ -105,8 +105,8 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])], input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
output_streams=[ output_streams=[
':'.join([ ':'.join([
_CLASSIFICATION_RESULT_TAG, _CLASSIFICATIONS_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME _CLASSIFICATIONS_STREAM_NAME
]) ])
], ],
task_options=options) task_options=options)
@ -132,9 +132,6 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto( packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) output_packets[_CLASSIFICATIONS_STREAM_NAME]))
return TextClassifierResult([ return TextClassifierResult.create_from_pb2(classification_result_proto)
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])

View File

@ -48,7 +48,7 @@ py_library(
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",

View File

@ -23,7 +23,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
@ -33,6 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
@ -41,8 +42,8 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
@ -72,7 +73,7 @@ class ImageClassifierOptions:
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: _ClassifierOptions = _ClassifierOptions() classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[ result_callback: Optional[
Callable[[classifications.ClassificationResult, image_module.Image, int], Callable[[ImageClassifierResult, image_module.Image, int],
None]] = None None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -137,17 +138,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto( packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
classification_result = classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image, options.result_callback(
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) ImageClassifierResult.create_from_pb2(classification_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -157,8 +154,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
], ],
output_streams=[ output_streams=[
':'.join([ ':'.join([
_CLASSIFICATION_RESULT_TAG, _CLASSIFICATIONS_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME _CLASSIFICATIONS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
], ],
task_options=options) task_options=options)
@ -172,7 +169,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult: ) -> ImageClassifierResult:
"""Performs image classification on the provided MediaPipe Image. """Performs image classification on the provided MediaPipe Image.
Args: Args:
@ -196,20 +193,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto( packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return classifications.ClassificationResult([ return ImageClassifierResult.create_from_pb2(classification_result_proto)
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_for_video( def classify_for_video(
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> classifications.ClassificationResult: ) -> ImageClassifierResult:
"""Performs image classification on the provided video frames. """Performs image classification on the provided video frames.
Only use this method when the ImageClassifier is created with the video Only use this method when the ImageClassifier is created with the video
@ -241,13 +234,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto( packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return classifications.ClassificationResult([ return ImageClassifierResult.create_from_pb2(classification_result_proto)
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_async( def classify_async(
self, self,