diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9e0a90911..b02021acd 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -85,21 +85,12 @@ py_library( ], ) -py_library( - name = "classifications", - srcs = ["classifications.py"], - deps = [ - ":category", - "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) - py_library( name = "classification_result", srcs = ["classification_result.py"], deps = [ ":category", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index cc25fc708..20dc6991a 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -16,10 +16,13 @@ import dataclasses from typing import List, Optional +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls +_ClassificationProto = classification_pb2.Classification +_ClassificationListProto = classification_pb2.ClassificationList _ClassificationsProto = classifications_pb2.Classifications _ClassificationResultProto = classifications_pb2.ClassificationResult @@ -41,6 +44,24 @@ class Classifications: head_index: int head_name: Optional[str] = None + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationsProto: + """Generates a Classifications protobuf object.""" + classification_list_proto = _ClassificationListProto() + classification_list_proto.classification.extend([ + _ClassificationProto( + index=category.index, + score=category.score, + label=category.category_name, + display_name=category.display_name + ) + for category in self.categories + ]) + return _ClassificationsProto( + classification_list=classification_list_proto, + head_index=self.head_index, + head_name=self.head_name) + @classmethod @doc_controls.do_not_generate_docs def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': @@ -78,6 +99,16 @@ class ClassificationResult: classifications: List[Classifications] timestamp_ms: Optional[int] = None + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationResultProto: + """Generates a ClassificationResult protobuf object.""" + return _ClassificationResultProto( + classifications=[ + classification.to_pb2() + for classification in self.classifications + ], + timestamp_ms=self.timestamp_ms) + @classmethod @doc_controls.do_not_generate_docs def create_from_pb2( diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py deleted file mode 100644 index 90ab22614..000000000 --- a/mediapipe/tasks/python/components/containers/classifications.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classifications data class.""" - -import dataclasses -from typing import Any, List, Optional - -from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 -from mediapipe.tasks.python.components.containers import category as category_module -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_ClassificationEntryProto = classifications_pb2.ClassificationEntry -_ClassificationsProto = classifications_pb2.Classifications -_ClassificationResultProto = classifications_pb2.ClassificationResult - - -@dataclasses.dataclass -class ClassificationEntry: - """List of predicted classes (aka labels) for a given classifier head. - - Attributes: - categories: The array of predicted categories, usually sorted by descending - scores (e.g. from high to low probability). - timestamp_ms: The optional timestamp (in milliseconds) associated to the - classification entry. This is useful for time series use cases, e.g., - audio classification. - """ - - categories: List[category_module.Category] - timestamp_ms: Optional[int] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ClassificationEntryProto: - """Generates a ClassificationEntry protobuf object.""" - return _ClassificationEntryProto( - categories=[category.to_pb2() for category in self.categories], - timestamp_ms=self.timestamp_ms) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry': - """Creates a `ClassificationEntry` object from the given protobuf object.""" - return ClassificationEntry( - categories=[ - category_module.Category.create_from_pb2(category) - for category in pb2_obj.categories - ], - timestamp_ms=pb2_obj.timestamp_ms) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, ClassificationEntry): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class Classifications: - """Represents the classifications for a given classifier head. - - Attributes: - entries: A list of `ClassificationEntry` objects. - head_index: The index of the classifier head these categories refer to. This - is useful for multi-head models. - head_name: The name of the classifier head, which is the corresponding - tensor metadata name. - """ - - entries: List[ClassificationEntry] - head_index: int - head_name: str - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ClassificationsProto: - """Generates a Classifications protobuf object.""" - return _ClassificationsProto( - entries=[entry.to_pb2() for entry in self.entries], - head_index=self.head_index, - head_name=self.head_name) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': - """Creates a `Classifications` object from the given protobuf object.""" - return Classifications( - entries=[ - ClassificationEntry.create_from_pb2(entry) - for entry in pb2_obj.entries - ], - head_index=pb2_obj.head_index, - head_name=pb2_obj.head_name) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Classifications): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class ClassificationResult: - """Contains one set of results per classifier head. - - Attributes: - classifications: A list of `Classifications` objects. - """ - - classifications: List[Classifications] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ClassificationResultProto: - """Generates a ClassificationResult protobuf object.""" - return _ClassificationResultProto(classifications=[ - classification.to_pb2() for classification in self.classifications - ]) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult': - """Creates a `ClassificationResult` object from the given protobuf object. - """ - return ClassificationResult(classifications=[ - Classifications.create_from_pb2(classification) - for classification in pb2_obj.classifications - ]) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, ClassificationResult): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index d7176b0a5..f12c20bc4 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -27,7 +27,7 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:category", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index c93def48e..d9221e503 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -20,18 +20,17 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier +TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category -_ClassificationEntry = classifications_module.ClassificationEntry -_Classifications = classifications_module.Classifications -_TextClassifierResult = classifications_module.ClassificationResult +_Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier _TextClassifierOptions = text_classifier.TextClassifierOptions @@ -43,90 +42,75 @@ _NEGATIVE_TEXT = 'What a waste of my time.' _POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.' 'Strongly recommend it!') -_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=0, - score=0.999479, - display_name='', - category_name='negative'), - _Category( - index=1, - score=0.00052154, - display_name='', - category_name='positive') - ], - timestamp_ms=0) + +_BERT_NEGATIVE_RESULTS = TextClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=0, + score=0.999479, + display_name='', + category_name='negative'), + _Category( + index=1, + score=0.00052154, + display_name='', + category_name='positive') ], head_index=0, head_name='probability') -]) -_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=1, - score=0.999466, - display_name='', - category_name='positive'), - _Category( - index=0, - score=0.000533596, - display_name='', - category_name='negative') - ], - timestamp_ms=0) + ], + timestamp_ms=0) +_BERT_POSITIVE_RESULTS = TextClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=1, + score=0.999466, + display_name='', + category_name='positive'), + _Category( + index=0, + score=0.000533596, + display_name='', + category_name='negative') ], head_index=0, head_name='probability') -]) -_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=0, - score=0.81313, - display_name='', - category_name='Negative'), - _Category( - index=1, - score=0.1868704, - display_name='', - category_name='Positive') - ], - timestamp_ms=0) + ], + timestamp_ms=0) +_REGEX_NEGATIVE_RESULTS = TextClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=0, + score=0.81313, + display_name='', + category_name='Negative'), + _Category( + index=1, + score=0.1868704, + display_name='', + category_name='Positive') ], head_index=0, head_name='probability') -]) -_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=1, - score=0.5134273, - display_name='', - category_name='Positive'), - _Category( - index=0, - score=0.486573, - display_name='', - category_name='Negative') - ], - timestamp_ms=0) + ], + timestamp_ms=0) +_REGEX_POSITIVE_RESULTS = TextClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=1, + score=0.5134273, + display_name='', + category_name='Positive'), + _Category( + index=0, + score=0.486573, + display_name='', + category_name='Negative') ], head_index=0, head_name='probability') -]) + ], + timestamp_ms=0) class ModelFileType(enum.Enum): diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index c29648160..1ea5f5b7d 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -47,7 +47,7 @@ py_test( deps = [ "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:category", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 11941ce23..1f4d39ec7 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -23,8 +23,8 @@ from absl.testing import parameterized import numpy as np from mediapipe.python._framework_bindings import image -from mediapipe.tasks.python.components.containers import category -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module @@ -33,13 +33,12 @@ from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions _ClassifierOptions = classifier_options.ClassifierOptions -_Category = category.Category -_ClassificationEntry = classifications_module.ClassificationEntry -_Classifications = classifications_module.Classifications -_ClassificationResult = classifications_module.ClassificationResult +_Category = category_module.Category +_Classifications = classification_result_module.Classifications _Image = image.Image _ImageClassifier = image_classifier.ImageClassifier _ImageClassifierOptions = image_classifier.ImageClassifierOptions @@ -55,68 +54,56 @@ _MAX_RESULTS = 3 _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' -def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: - return _ClassificationResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry(categories=[], timestamp_ms=timestamp_ms) +def _generate_empty_results() -> ImageClassifierResult: + return ImageClassifierResult(classifications=[ + _Classifications(categories=[], head_index=0, head_name='probability') + ], + timestamp_ms=0) + + +def _generate_burger_results() -> ImageClassifierResult: + return ImageClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=934, + score=0.793959, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.0273929, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.0193408, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.00632786, + display_name='', + category_name='meat loaf') ], head_index=0, head_name='probability') - ]) + ], + timestamp_ms=0) -def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: - return _ClassificationResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=934, - score=0.793959, - display_name='', - category_name='cheeseburger'), - _Category( - index=932, - score=0.0273929, - display_name='', - category_name='bagel'), - _Category( - index=925, - score=0.0193408, - display_name='', - category_name='guacamole'), - _Category( - index=963, - score=0.00632786, - display_name='', - category_name='meat loaf') - ], - timestamp_ms=timestamp_ms) +def _generate_soccer_ball_results() -> ImageClassifierResult: + return ImageClassifierResult(classifications=[ + _Classifications(categories=[ + _Category( + index=806, + score=0.996527, + display_name='', + category_name='soccer ball') ], head_index=0, head_name='probability') - ]) - - -def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: - return _ClassificationResult(classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[ - _Category( - index=806, - score=0.996527, - display_name='', - category_name='soccer ball') - ], - timestamp_ms=timestamp_ms) - ], - head_index=0, - head_name='probability') - ]) + ], + timestamp_ms=0) class ModelFileType(enum.Enum): @@ -165,8 +152,8 @@ class ImageClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _ImageClassifier) @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), - (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) + (ModelFileType.FILE_NAME, 4, _generate_burger_results()), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results())) def test_classify(self, model_file_type, max_results, expected_classification_result): # Creates classifier. @@ -195,8 +182,8 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), - (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) + (ModelFileType.FILE_NAME, 4, _generate_burger_results()), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results())) def test_classify_in_context(self, model_file_type, max_results, expected_classification_result): if model_file_type is ModelFileType.FILE_NAME: @@ -236,7 +223,7 @@ class ImageClassifierTest(parameterized.TestCase): image_result = classifier.classify(test_image, image_processing_options) # Comparing results. test_utils.assert_proto_equals(self, image_result.to_pb2(), - _generate_soccer_ball_results(0).to_pb2()) + _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): custom_classifier_options = _ClassifierOptions( @@ -250,8 +237,8 @@ class ImageClassifierTest(parameterized.TestCase): classifications = image_result.classifications for classification in classifications: - for entry in classification.entries: - score = entry.categories[0].score + for category in classification.categories: + score = category.score self.assertGreaterEqual( score, _SCORE_THRESHOLD, f'Classification with score lower than threshold found. ' @@ -266,7 +253,7 @@ class ImageClassifierTest(parameterized.TestCase): with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) - categories = image_result.classifications[0].entries[0].categories + categories = image_result.classifications[0].categories self.assertLessEqual( len(categories), _MAX_RESULTS, 'Too many results returned.') @@ -283,8 +270,8 @@ class ImageClassifierTest(parameterized.TestCase): classifications = image_result.classifications for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name + for category in classification.categories: + label = category.category_name self.assertIn(label, _ALLOW_LIST, f'Label {label} found but not in label allow list') @@ -299,8 +286,8 @@ class ImageClassifierTest(parameterized.TestCase): classifications = image_result.classifications for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name + for category in classification.categories: + label = category.category_name self.assertNotIn(label, _DENY_LIST, f'Label {label} found but in deny list.') @@ -326,7 +313,7 @@ class ImageClassifierTest(parameterized.TestCase): with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) - self.assertEmpty(image_result.classifications[0].entries[0].categories) + self.assertEmpty(image_result.classifications[0].categories) def test_missing_result_callback(self): options = _ImageClassifierOptions( @@ -406,7 +393,7 @@ class ImageClassifierTest(parameterized.TestCase): self.test_image, timestamp) test_utils.assert_proto_equals( self, classification_result.to_pb2(), - _generate_burger_results(timestamp).to_pb2()) + _generate_burger_results().to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): custom_classifier_options = _ClassifierOptions(max_results=1) @@ -427,7 +414,7 @@ class ImageClassifierTest(parameterized.TestCase): test_image, timestamp, image_processing_options) test_utils.assert_proto_equals( self, classification_result.to_pb2(), - _generate_soccer_ball_results(timestamp).to_pb2()) + _generate_soccer_ball_results().to_pb2()) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -462,15 +449,15 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'Input timestamp must be monotonically increasing'): classifier.classify_async(self.test_image, 0) - @parameterized.parameters((0, _generate_burger_results), - (1, _generate_empty_results)) - def test_classify_async_calls(self, threshold, expected_result_fn): + @parameterized.parameters((0, _generate_burger_results()), + (1, _generate_empty_results())) + def test_classify_async_calls(self, threshold, expected_result): observed_timestamp_ms = -1 - def check_result(result: _ClassificationResult, output_image: _Image, + def check_result(result: ImageClassifierResult, output_image: _Image, timestamp_ms: int): test_utils.assert_proto_equals(self, result.to_pb2(), - expected_result_fn(timestamp_ms).to_pb2()) + expected_result.to_pb2()) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -498,11 +485,11 @@ class ImageClassifierTest(parameterized.TestCase): image_processing_options = _ImageProcessingOptions(roi) observed_timestamp_ms = -1 - def check_result(result: _ClassificationResult, output_image: _Image, + def check_result(result: ImageClassifierResult, output_image: _Image, timestamp_ms: int): test_utils.assert_proto_equals( self, result.to_pb2(), - _generate_soccer_ball_results(timestamp_ms).to_pb2()) + _generate_soccer_ball_results().to_pb2()) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index fd5d701b4..c7278ecb8 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -28,7 +28,7 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 1e230ee20..d4259d121 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -19,21 +19,21 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 -from mediapipe.tasks.python.components.containers import classifications +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.text.core import base_text_task_api -TextClassifierResult = classifications.ClassificationResult +TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _TaskInfo = task_info_module.TaskInfo -_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' -_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' _TEXT_IN_STREAM_NAME = 'text_in' _TEXT_TAG = 'TEXT' _TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' @@ -105,8 +105,8 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi): input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])], output_streams=[ ':'.join([ - _CLASSIFICATION_RESULT_TAG, - _CLASSIFICATION_RESULT_OUT_STREAM_NAME + _CLASSIFICATIONS_TAG, + _CLASSIFICATIONS_STREAM_NAME ]) ], task_options=options) @@ -132,9 +132,6 @@ class TextClassifier(base_text_task_api.BaseTextTaskApi): classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( packet_getter.get_proto( - output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + output_packets[_CLASSIFICATIONS_STREAM_NAME])) - return TextClassifierResult([ - classifications.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications - ]) + return TextClassifierResult.create_from_pb2(classification_result_proto) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 527c6d883..9b4e0ef1b 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -48,7 +48,7 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 02819ddf1..377d3fdcb 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -23,7 +23,7 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 -from mediapipe.tasks.python.components.containers import classifications +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module @@ -33,6 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions @@ -41,8 +42,8 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' -_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' @@ -72,7 +73,7 @@ class ImageClassifierOptions: running_mode: _RunningMode = _RunningMode.IMAGE classifier_options: _ClassifierOptions = _ClassifierOptions() result_callback: Optional[ - Callable[[classifications.ClassificationResult, image_module.Image, int], + Callable[[ImageClassifierResult, image_module.Image, int], None]] = None @doc_controls.do_not_generate_docs @@ -137,17 +138,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto( - output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) - - classification_result = classifications.ClassificationResult([ - classifications.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications - ]) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp - options.result_callback(classification_result, image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + options.result_callback( + ImageClassifierResult.create_from_pb2(classification_result_proto), + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -157,8 +154,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ], output_streams=[ ':'.join([ - _CLASSIFICATION_RESULT_TAG, - _CLASSIFICATION_RESULT_OUT_STREAM_NAME + _CLASSIFICATIONS_TAG, + _CLASSIFICATIONS_STREAM_NAME ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ], task_options=options) @@ -172,7 +169,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> classifications.ClassificationResult: + ) -> ImageClassifierResult: """Performs image classification on the provided MediaPipe Image. Args: @@ -196,20 +193,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto( - output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) - return classifications.ClassificationResult([ - classifications.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications - ]) + return ImageClassifierResult.create_from_pb2(classification_result_proto) def classify_for_video( self, image: image_module.Image, timestamp_ms: int, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> classifications.ClassificationResult: + ) -> ImageClassifierResult: """Performs image classification on the provided video frames. Only use this method when the ImageClassifier is created with the video @@ -241,13 +234,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto( - output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) - return classifications.ClassificationResult([ - classifications.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications - ]) + return ImageClassifierResult.create_from_pb2(classification_result_proto) def classify_async( self,