Merge pull request #3832 from kinaryml:update-classification-tasks-output-stream
PiperOrigin-RevId: 487638077
This commit is contained in:
commit
bbf4ff0300
|
@ -85,21 +85,12 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "classifications",
|
|
||||||
srcs = ["classifications.py"],
|
|
||||||
deps = [
|
|
||||||
":category",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "classification_result",
|
name = "classification_result",
|
||||||
srcs = ["classification_result.py"],
|
srcs = ["classification_result.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":category",
|
":category",
|
||||||
|
"//mediapipe/framework/formats:classification_py_pb2",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
|
|
|
@ -16,10 +16,13 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from mediapipe.framework.formats import classification_pb2
|
||||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_ClassificationProto = classification_pb2.Classification
|
||||||
|
_ClassificationListProto = classification_pb2.ClassificationList
|
||||||
_ClassificationsProto = classifications_pb2.Classifications
|
_ClassificationsProto = classifications_pb2.Classifications
|
||||||
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
||||||
|
|
||||||
|
@ -41,6 +44,22 @@ class Classifications:
|
||||||
head_index: int
|
head_index: int
|
||||||
head_name: Optional[str] = None
|
head_name: Optional[str] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassificationsProto:
|
||||||
|
"""Generates a Classifications protobuf object."""
|
||||||
|
classification_list_proto = _ClassificationListProto()
|
||||||
|
for category in self.categories:
|
||||||
|
classification_proto = _ClassificationProto(
|
||||||
|
index=category.index,
|
||||||
|
score=category.score,
|
||||||
|
label=category.category_name,
|
||||||
|
display_name=category.display_name)
|
||||||
|
classification_list_proto.classification.append(classification_proto)
|
||||||
|
return _ClassificationsProto(
|
||||||
|
classification_list=classification_list_proto,
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
||||||
|
@ -78,6 +97,15 @@ class ClassificationResult:
|
||||||
classifications: List[Classifications]
|
classifications: List[Classifications]
|
||||||
timestamp_ms: Optional[int] = None
|
timestamp_ms: Optional[int] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassificationResultProto:
|
||||||
|
"""Generates a ClassificationResult protobuf object."""
|
||||||
|
return _ClassificationResultProto(
|
||||||
|
classifications=[
|
||||||
|
classification.to_pb2() for classification in self.classifications
|
||||||
|
],
|
||||||
|
timestamp_ms=self.timestamp_ms)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def create_from_pb2(
|
def create_from_pb2(
|
||||||
|
|
|
@ -1,168 +0,0 @@
|
||||||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Classifications data class."""
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
|
||||||
from mediapipe.tasks.python.components.containers import category as category_module
|
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|
||||||
|
|
||||||
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
|
|
||||||
_ClassificationsProto = classifications_pb2.Classifications
|
|
||||||
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ClassificationEntry:
|
|
||||||
"""List of predicted classes (aka labels) for a given classifier head.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
categories: The array of predicted categories, usually sorted by descending
|
|
||||||
scores (e.g. from high to low probability).
|
|
||||||
timestamp_ms: The optional timestamp (in milliseconds) associated to the
|
|
||||||
classification entry. This is useful for time series use cases, e.g.,
|
|
||||||
audio classification.
|
|
||||||
"""
|
|
||||||
|
|
||||||
categories: List[category_module.Category]
|
|
||||||
timestamp_ms: Optional[int] = None
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _ClassificationEntryProto:
|
|
||||||
"""Generates a ClassificationEntry protobuf object."""
|
|
||||||
return _ClassificationEntryProto(
|
|
||||||
categories=[category.to_pb2() for category in self.categories],
|
|
||||||
timestamp_ms=self.timestamp_ms)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
|
|
||||||
"""Creates a `ClassificationEntry` object from the given protobuf object."""
|
|
||||||
return ClassificationEntry(
|
|
||||||
categories=[
|
|
||||||
category_module.Category.create_from_pb2(category)
|
|
||||||
for category in pb2_obj.categories
|
|
||||||
],
|
|
||||||
timestamp_ms=pb2_obj.timestamp_ms)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, ClassificationEntry):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Classifications:
|
|
||||||
"""Represents the classifications for a given classifier head.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
entries: A list of `ClassificationEntry` objects.
|
|
||||||
head_index: The index of the classifier head these categories refer to. This
|
|
||||||
is useful for multi-head models.
|
|
||||||
head_name: The name of the classifier head, which is the corresponding
|
|
||||||
tensor metadata name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
entries: List[ClassificationEntry]
|
|
||||||
head_index: int
|
|
||||||
head_name: str
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _ClassificationsProto:
|
|
||||||
"""Generates a Classifications protobuf object."""
|
|
||||||
return _ClassificationsProto(
|
|
||||||
entries=[entry.to_pb2() for entry in self.entries],
|
|
||||||
head_index=self.head_index,
|
|
||||||
head_name=self.head_name)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
|
||||||
"""Creates a `Classifications` object from the given protobuf object."""
|
|
||||||
return Classifications(
|
|
||||||
entries=[
|
|
||||||
ClassificationEntry.create_from_pb2(entry)
|
|
||||||
for entry in pb2_obj.entries
|
|
||||||
],
|
|
||||||
head_index=pb2_obj.head_index,
|
|
||||||
head_name=pb2_obj.head_name)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Classifications):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ClassificationResult:
|
|
||||||
"""Contains one set of results per classifier head.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
classifications: A list of `Classifications` objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
classifications: List[Classifications]
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _ClassificationResultProto:
|
|
||||||
"""Generates a ClassificationResult protobuf object."""
|
|
||||||
return _ClassificationResultProto(classifications=[
|
|
||||||
classification.to_pb2() for classification in self.classifications
|
|
||||||
])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
|
|
||||||
"""Creates a `ClassificationResult` object from the given protobuf object.
|
|
||||||
"""
|
|
||||||
return ClassificationResult(classifications=[
|
|
||||||
Classifications.create_from_pb2(classification)
|
|
||||||
for classification in pb2_obj.classifications
|
|
||||||
])
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, ClassificationResult):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
|
@ -27,7 +27,7 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/components/containers:category",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
|
|
@ -20,18 +20,17 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from mediapipe.tasks.python.components.containers import category
|
from mediapipe.tasks.python.components.containers import category
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
from mediapipe.tasks.python.components.processors import classifier_options
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
from mediapipe.tasks.python.text import text_classifier
|
from mediapipe.tasks.python.text import text_classifier
|
||||||
|
|
||||||
|
TextClassifierResult = classification_result_module.ClassificationResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_Category = category.Category
|
_Category = category.Category
|
||||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
_Classifications = classification_result_module.Classifications
|
||||||
_Classifications = classifications_module.Classifications
|
|
||||||
_TextClassifierResult = classifications_module.ClassificationResult
|
|
||||||
_TextClassifier = text_classifier.TextClassifier
|
_TextClassifier = text_classifier.TextClassifier
|
||||||
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
||||||
|
|
||||||
|
@ -43,10 +42,9 @@ _NEGATIVE_TEXT = 'What a waste of my time.'
|
||||||
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
||||||
'Strongly recommend it!')
|
'Strongly recommend it!')
|
||||||
|
|
||||||
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
_BERT_NEGATIVE_RESULTS = TextClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=0,
|
index=0,
|
||||||
|
@ -59,15 +57,13 @@ _BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='positive')
|
category_name='positive')
|
||||||
],
|
],
|
||||||
timestamp_ms=0)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
timestamp_ms=0)
|
||||||
|
_BERT_POSITIVE_RESULTS = TextClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=1,
|
index=1,
|
||||||
|
@ -80,15 +76,13 @@ _BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='negative')
|
category_name='negative')
|
||||||
],
|
],
|
||||||
timestamp_ms=0)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
timestamp_ms=0)
|
||||||
|
_REGEX_NEGATIVE_RESULTS = TextClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=0,
|
index=0,
|
||||||
|
@ -101,15 +95,13 @@ _REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='Positive')
|
category_name='Positive')
|
||||||
],
|
],
|
||||||
timestamp_ms=0)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
timestamp_ms=0)
|
||||||
|
_REGEX_POSITIVE_RESULTS = TextClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=1,
|
index=1,
|
||||||
|
@ -122,11 +114,10 @@ _REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='Negative')
|
category_name='Negative')
|
||||||
],
|
],
|
||||||
timestamp_ms=0)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
|
||||||
|
|
||||||
class ModelFileType(enum.Enum):
|
class ModelFileType(enum.Enum):
|
||||||
|
|
|
@ -47,7 +47,7 @@ py_test(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/python/components/containers:category",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
|
|
@ -23,8 +23,8 @@ from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image
|
from mediapipe.python._framework_bindings import image
|
||||||
from mediapipe.tasks.python.components.containers import category
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
from mediapipe.tasks.python.components.containers import rect
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
from mediapipe.tasks.python.components.processors import classifier_options
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -33,13 +33,12 @@ from mediapipe.tasks.python.vision import image_classifier
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
|
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||||
_Rect = rect.Rect
|
_Rect = rect.Rect
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_Category = category.Category
|
_Category = category_module.Category
|
||||||
_ClassificationEntry = classifications_module.ClassificationEntry
|
_Classifications = classification_result_module.Classifications
|
||||||
_Classifications = classifications_module.Classifications
|
|
||||||
_ClassificationResult = classifications_module.ClassificationResult
|
|
||||||
_Image = image.Image
|
_Image = image.Image
|
||||||
_ImageClassifier = image_classifier.ImageClassifier
|
_ImageClassifier = image_classifier.ImageClassifier
|
||||||
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
||||||
|
@ -55,22 +54,19 @@ _MAX_RESULTS = 3
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
|
||||||
|
|
||||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_empty_results() -> ImageClassifierResult:
|
||||||
return _ClassificationResult(classifications=[
|
return ImageClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
categories=[], head_index=0, head_name='probability')
|
||||||
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
|
|
||||||
],
|
],
|
||||||
head_index=0,
|
timestamp_ms=0)
|
||||||
head_name='probability')
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_burger_results() -> ImageClassifierResult:
|
||||||
return _ClassificationResult(classifications=[
|
return ImageClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=934,
|
index=934,
|
||||||
|
@ -93,18 +89,16 @@ def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='meat loaf')
|
category_name='meat loaf')
|
||||||
],
|
],
|
||||||
timestamp_ms=timestamp_ms)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
|
||||||
|
|
||||||
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_soccer_ball_results() -> ImageClassifierResult:
|
||||||
return _ClassificationResult(classifications=[
|
return ImageClassifierResult(
|
||||||
|
classifications=[
|
||||||
_Classifications(
|
_Classifications(
|
||||||
entries=[
|
|
||||||
_ClassificationEntry(
|
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=806,
|
index=806,
|
||||||
|
@ -112,11 +106,10 @@ def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
display_name='',
|
display_name='',
|
||||||
category_name='soccer ball')
|
category_name='soccer ball')
|
||||||
],
|
],
|
||||||
timestamp_ms=timestamp_ms)
|
|
||||||
],
|
|
||||||
head_index=0,
|
head_index=0,
|
||||||
head_name='probability')
|
head_name='probability')
|
||||||
])
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
|
||||||
|
|
||||||
class ModelFileType(enum.Enum):
|
class ModelFileType(enum.Enum):
|
||||||
|
@ -163,8 +156,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
self.assertIsInstance(classifier, _ImageClassifier)
|
self.assertIsInstance(classifier, _ImageClassifier)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
|
||||||
def test_classify(self, model_file_type, max_results,
|
def test_classify(self, model_file_type, max_results,
|
||||||
expected_classification_result):
|
expected_classification_result):
|
||||||
# Creates classifier.
|
# Creates classifier.
|
||||||
|
@ -193,8 +186,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
classifier.close()
|
classifier.close()
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
|
||||||
def test_classify_in_context(self, model_file_type, max_results,
|
def test_classify_in_context(self, model_file_type, max_results,
|
||||||
expected_classification_result):
|
expected_classification_result):
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
@ -234,7 +227,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
image_result = classifier.classify(test_image, image_processing_options)
|
image_result = classifier.classify(test_image, image_processing_options)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
test_utils.assert_proto_equals(self, image_result.to_pb2(),
|
||||||
_generate_soccer_ball_results(0).to_pb2())
|
_generate_soccer_ball_results().to_pb2())
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
custom_classifier_options = _ClassifierOptions(
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
@ -248,8 +241,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
classifications = image_result.classifications
|
classifications = image_result.classifications
|
||||||
|
|
||||||
for classification in classifications:
|
for classification in classifications:
|
||||||
for entry in classification.entries:
|
for category in classification.categories:
|
||||||
score = entry.categories[0].score
|
score = category.score
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
score, _SCORE_THRESHOLD,
|
score, _SCORE_THRESHOLD,
|
||||||
f'Classification with score lower than threshold found. '
|
f'Classification with score lower than threshold found. '
|
||||||
|
@ -264,7 +257,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
categories = image_result.classifications[0].entries[0].categories
|
categories = image_result.classifications[0].categories
|
||||||
|
|
||||||
self.assertLessEqual(
|
self.assertLessEqual(
|
||||||
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
||||||
|
@ -281,8 +274,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
classifications = image_result.classifications
|
classifications = image_result.classifications
|
||||||
|
|
||||||
for classification in classifications:
|
for classification in classifications:
|
||||||
for entry in classification.entries:
|
for category in classification.categories:
|
||||||
label = entry.categories[0].category_name
|
label = category.category_name
|
||||||
self.assertIn(label, _ALLOW_LIST,
|
self.assertIn(label, _ALLOW_LIST,
|
||||||
f'Label {label} found but not in label allow list')
|
f'Label {label} found but not in label allow list')
|
||||||
|
|
||||||
|
@ -297,8 +290,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
classifications = image_result.classifications
|
classifications = image_result.classifications
|
||||||
|
|
||||||
for classification in classifications:
|
for classification in classifications:
|
||||||
for entry in classification.entries:
|
for category in classification.categories:
|
||||||
label = entry.categories[0].category_name
|
label = category.category_name
|
||||||
self.assertNotIn(label, _DENY_LIST,
|
self.assertNotIn(label, _DENY_LIST,
|
||||||
f'Label {label} found but in deny list.')
|
f'Label {label} found but in deny list.')
|
||||||
|
|
||||||
|
@ -324,7 +317,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Performs image classification on the input.
|
# Performs image classification on the input.
|
||||||
image_result = classifier.classify(self.test_image)
|
image_result = classifier.classify(self.test_image)
|
||||||
self.assertEmpty(image_result.classifications[0].entries[0].categories)
|
self.assertEmpty(image_result.classifications[0].categories)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
|
@ -402,9 +395,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
self.test_image, timestamp)
|
self.test_image, timestamp)
|
||||||
test_utils.assert_proto_equals(
|
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||||
self, classification_result.to_pb2(),
|
_generate_burger_results().to_pb2())
|
||||||
_generate_burger_results(timestamp).to_pb2())
|
|
||||||
|
|
||||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||||
custom_classifier_options = _ClassifierOptions(max_results=1)
|
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||||
|
@ -423,9 +415,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
classification_result = classifier.classify_for_video(
|
classification_result = classifier.classify_for_video(
|
||||||
test_image, timestamp, image_processing_options)
|
test_image, timestamp, image_processing_options)
|
||||||
test_utils.assert_proto_equals(
|
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||||
self, classification_result.to_pb2(),
|
_generate_soccer_ball_results().to_pb2())
|
||||||
_generate_soccer_ball_results(timestamp).to_pb2())
|
|
||||||
|
|
||||||
def test_calling_classify_in_live_stream_mode(self):
|
def test_calling_classify_in_live_stream_mode(self):
|
||||||
options = _ImageClassifierOptions(
|
options = _ImageClassifierOptions(
|
||||||
|
@ -460,15 +451,15 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
classifier.classify_async(self.test_image, 0)
|
classifier.classify_async(self.test_image, 0)
|
||||||
|
|
||||||
@parameterized.parameters((0, _generate_burger_results),
|
@parameterized.parameters((0, _generate_burger_results()),
|
||||||
(1, _generate_empty_results))
|
(1, _generate_empty_results()))
|
||||||
def test_classify_async_calls(self, threshold, expected_result_fn):
|
def test_classify_async_calls(self, threshold, expected_result):
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: ImageClassifierResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
test_utils.assert_proto_equals(self, result.to_pb2(),
|
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||||
expected_result_fn(timestamp_ms).to_pb2())
|
expected_result.to_pb2())
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
self.test_image.numpy_view()))
|
self.test_image.numpy_view()))
|
||||||
|
@ -496,11 +487,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
image_processing_options = _ImageProcessingOptions(roi)
|
image_processing_options = _ImageProcessingOptions(roi)
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: _ClassificationResult, output_image: _Image,
|
def check_result(result: ImageClassifierResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
test_utils.assert_proto_equals(
|
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||||
self, result.to_pb2(),
|
_generate_soccer_ball_results().to_pb2())
|
||||||
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
|
||||||
self.assertEqual(output_image.width, test_image.width)
|
self.assertEqual(output_image.width, test_image.width)
|
||||||
self.assertEqual(output_image.height, test_image.height)
|
self.assertEqual(output_image.height, test_image.height)
|
||||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
|
|
@ -28,7 +28,7 @@ py_library(
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
|
|
@ -19,21 +19,21 @@ from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.containers import classifications
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
from mediapipe.tasks.python.components.processors import classifier_options
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
from mediapipe.tasks.python.text.core import base_text_task_api
|
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||||
|
|
||||||
TextClassifierResult = classifications.ClassificationResult
|
TextClassifierResult = classification_result_module.ClassificationResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
|
||||||
_TEXT_IN_STREAM_NAME = 'text_in'
|
_TEXT_IN_STREAM_NAME = 'text_in'
|
||||||
_TEXT_TAG = 'TEXT'
|
_TEXT_TAG = 'TEXT'
|
||||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
|
||||||
|
@ -104,10 +104,7 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([
|
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
|
||||||
_CLASSIFICATION_RESULT_TAG,
|
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
|
||||||
])
|
|
||||||
],
|
],
|
||||||
task_options=options)
|
task_options=options)
|
||||||
return cls(task_info.generate_graph_config())
|
return cls(task_info.generate_graph_config())
|
||||||
|
@ -131,10 +128,6 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
||||||
|
|
||||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
classification_result_proto.CopyFrom(
|
classification_result_proto.CopyFrom(
|
||||||
packet_getter.get_proto(
|
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
|
||||||
|
|
||||||
return TextClassifierResult([
|
return TextClassifierResult.create_from_pb2(classification_result_proto)
|
||||||
classifications.Classifications.create_from_pb2(classification)
|
|
||||||
for classification in classification_result_proto.classifications
|
|
||||||
])
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ py_library(
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:classifications",
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
|
|
@ -18,12 +18,11 @@ from typing import Callable, Mapping, Optional
|
||||||
|
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
# TODO: Import MPImage directly one we have an alias
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet
|
from mediapipe.python._framework_bindings import packet
|
||||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.containers import classifications
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
from mediapipe.tasks.python.components.containers import rect
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
from mediapipe.tasks.python.components.processors import classifier_options
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -33,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
|
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||||
_NormalizedRect = rect.NormalizedRect
|
_NormalizedRect = rect.NormalizedRect
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||||
|
@ -41,8 +41,8 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
|
||||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
|
@ -71,9 +71,8 @@ class ImageClassifierOptions:
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||||
result_callback: Optional[
|
result_callback: Optional[Callable[
|
||||||
Callable[[classifications.ClassificationResult, image_module.Image, int],
|
[ImageClassifierResult, image_module.Image, int], None]] = None
|
||||||
None]] = None
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
||||||
|
@ -137,17 +136,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
classification_result_proto.CopyFrom(
|
classification_result_proto.CopyFrom(
|
||||||
packet_getter.get_proto(
|
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
|
||||||
|
|
||||||
classification_result = classifications.ClassificationResult([
|
|
||||||
classifications.Classifications.create_from_pb2(classification)
|
|
||||||
for classification in classification_result_proto.classifications
|
|
||||||
])
|
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
options.result_callback(classification_result, image,
|
options.result_callback(
|
||||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
ImageClassifierResult.create_from_pb2(classification_result_proto),
|
||||||
|
image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
@ -156,10 +150,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||||
],
|
],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([
|
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
|
||||||
_CLASSIFICATION_RESULT_TAG,
|
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
|
||||||
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
|
||||||
],
|
],
|
||||||
task_options=options)
|
task_options=options)
|
||||||
return cls(
|
return cls(
|
||||||
|
@ -172,7 +164,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> classifications.ClassificationResult:
|
) -> ImageClassifierResult:
|
||||||
"""Performs image classification on the provided MediaPipe Image.
|
"""Performs image classification on the provided MediaPipe Image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -196,20 +188,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
classification_result_proto.CopyFrom(
|
classification_result_proto.CopyFrom(
|
||||||
packet_getter.get_proto(
|
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
|
||||||
|
|
||||||
return classifications.ClassificationResult([
|
return ImageClassifierResult.create_from_pb2(classification_result_proto)
|
||||||
classifications.Classifications.create_from_pb2(classification)
|
|
||||||
for classification in classification_result_proto.classifications
|
|
||||||
])
|
|
||||||
|
|
||||||
def classify_for_video(
|
def classify_for_video(
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
timestamp_ms: int,
|
timestamp_ms: int,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> classifications.ClassificationResult:
|
) -> ImageClassifierResult:
|
||||||
"""Performs image classification on the provided video frames.
|
"""Performs image classification on the provided video frames.
|
||||||
|
|
||||||
Only use this method when the ImageClassifier is created with the video
|
Only use this method when the ImageClassifier is created with the video
|
||||||
|
@ -241,13 +229,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
classification_result_proto = classifications_pb2.ClassificationResult()
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
classification_result_proto.CopyFrom(
|
classification_result_proto.CopyFrom(
|
||||||
packet_getter.get_proto(
|
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
|
||||||
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
|
||||||
|
|
||||||
return classifications.ClassificationResult([
|
return ImageClassifierResult.create_from_pb2(classification_result_proto)
|
||||||
classifications.Classifications.create_from_pb2(classification)
|
|
||||||
for classification in classification_result_proto.classifications
|
|
||||||
])
|
|
||||||
|
|
||||||
def classify_async(
|
def classify_async(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user