From cb52432159ecfea69c6df54be2cb56fd569f275f Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 8 Sep 2022 06:23:03 -0700 Subject: [PATCH 01/23] Added image classification implementation files and associated tests --- .../tasks/python/components/containers/BUILD | 10 + .../components/containers/classifications.py | 169 ++++++++++ mediapipe/tasks/python/test/vision/BUILD | 38 ++- .../test/vision/image_classification_test.py | 301 ++++++++++++++++++ mediapipe/tasks/python/vision/BUILD | 20 ++ .../python/vision/image_classification.py | 227 +++++++++++++ 6 files changed, 764 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/python/components/containers/classifications.py create mode 100644 mediapipe/tasks/python/test/vision/image_classification_test.py create mode 100644 mediapipe/tasks/python/vision/image_classification.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 2bc951220..eb3acdd97 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -47,3 +47,13 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "classifications", + srcs = ["classifications.py"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers:classifications_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py new file mode 100644 index 000000000..19c5decde --- /dev/null +++ b/mediapipe/tasks/python/components/containers/classifications.py @@ -0,0 +1,169 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifications data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components.containers import classifications_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassificationEntryProto = classifications_pb2.ClassificationEntry +_ClassificationsProto = classifications_pb2.Classifications +_ClassificationResultProto = classifications_pb2.ClassificationResult + + +@dataclasses.dataclass +class ClassificationEntry: + """List of predicted classes (aka labels) for a given classifier head. + + Attributes: + categories: The array of predicted categories, usually sorted by descending + scores (e.g. from high to low probability). + timestamp_ms: The optional timestamp (in milliseconds) associated to the + classification entry. This is useful for time series use cases, e.g., + audio classification. 
+ """ + + categories: List[category_module.Category] + timestamp_ms: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationEntryProto: + """Generates a ClassificationEntry protobuf object.""" + return _ClassificationEntryProto( + categories=[category.to_pb2() for category in self.categories], + timestamp_ms=self.timestamp_ms) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry': + """Creates a `ClassificationEntry` object from the given protobuf object.""" + return ClassificationEntry( + categories=[ + category_module.Category.create_from_pb2(category) + for category in pb2_obj.categories + ], + timestamp_ms=pb2_obj.timestamp_ms) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationEntry): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Classifications: + """Represents the classifications for a given classifier head. + + Attributes: + entries: A list of `ClassificationEntry` objects. + head_index: The index of the classifier head these categories refer to. + This is useful for multi-head models. + head_name: The name of the classifier head, which is the corresponding + tensor metadata name. 
+ """ + + entries: List[ClassificationEntry] + head_index: int + head_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationsProto: + """Generates a Classifications protobuf object.""" + return _ClassificationsProto( + entries=[entry.to_pb2() for entry in self.entries], + head_index=self.head_index, + head_name=self.head_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': + """Creates a `Classifications` object from the given protobuf object.""" + return Classifications( + entries=[ + ClassificationEntry.create_from_pb2(entry) + for entry in pb2_obj.entries + ], + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Classifications): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class ClassificationResult: + """Contains one set of results per classifier head. + + Attributes: + classifications: A list of `Classifications` objects. 
+ """ + + classifications: List[Classifications] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationResultProto: + """Generates a ClassificationResult protobuf object.""" + return _ClassificationResultProto( + classifications=[ + classification.to_pb2() for classification in self.classifications + ]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult': + """Creates a `ClassificationResult` object from the given protobuf object.""" + return ClassificationResult( + classifications=[ + Classifications.create_from_pb2(classification) + for classification in pb2_obj.classifications + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index bb495338d..a63c36b55 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -18,4 +18,40 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -# TODO: This test fails in OSS +py_test( + name = "object_detector_test", + srcs = ["object_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + # build rule placeholder: numpy dep, + "//mediapipe/tasks/python/components/containers:bounding_box", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/vision:object_detector", + 
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "image_classification_test", + srcs = ["image_classification_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/vision:image_classification", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py new file mode 100644 index 000000000..3650c547c --- /dev/null +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -0,0 +1,301 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for image classifier.""" + +import enum + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.vision import image_classification +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_Category = category_module.Category +_ClassificationEntry = classifications_module.ClassificationEntry +_Classifications = classifications_module.Classifications +_ClassificationResult = classifications_module.ClassificationResult +_Image = image_module.Image +_ImageClassifier = image_classification.ImageClassifier +_ImageClassifierOptions = image_classification.ImageClassifierOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode + +_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' +_IMAGE_FILE = 'burger.jpg' +_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=934, + score=0.7952049970626831, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02732999622821808, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01933487318456173, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006279350258409977, + display_name='', + category_name='meat loaf') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_ALLOW_LIST = ['cheeseburger', 'guacamole'] +_DENY_LIST = ['cheeseburger'] +_SCORE_THRESHOLD = 0.5 
+_MAX_RESULTS = 3 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = test_util.read_test_image( + test_util.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_util.get_test_data_path(_MODEL_FILE) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _ImageClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(file_name=self.model_path) + options = _ImageClassifierOptions(base_options=base_options) + with _ImageClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + base_options = _BaseOptions(file_name='') + options = _ImageClassifierOptions(base_options=base_options) + _ImageClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. 
+ with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(file_content=f.read()) + options = _ImageClassifierOptions(base_options=base_options) + classifier = _ImageClassifier.create_from_options(options) + self.assertIsInstance(classifier, _ImageClassifier) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + def test_classify(self, model_file_type, max_results, + expected_classification_result): + # Creates classifier. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(file_name=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(file_content=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _ImageClassifierOptions( + base_options=base_options, max_results=max_results) + classifier = _ImageClassifier.create_from_options(options) + + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + # Comparing results. + self.assertEqual(image_result, expected_classification_result) + # Closes the classifier explicitly when the classifier is not used in + # a context. 
+ classifier.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + def test_classify_in_context(self, model_file_type, max_results, + expected_classification_result): + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(file_name=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(file_content=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _ImageClassifierOptions( + base_options=base_options, max_results=max_results) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs object detection on the input. + image_result = classifier.classify(self.test_image) + # Comparing results. + self.assertEqual(image_result, expected_classification_result) + + def test_score_threshold_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + score_threshold=_SCORE_THRESHOLD) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + score = entry.categories[0].score + self.assertGreaterEqual( + score, _SCORE_THRESHOLD, + f'Classification with score lower than threshold found. ' + f'{classification}') + + def test_max_results_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + max_results=_MAX_RESULTS) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + categories = image_result.classifications[0].entries[0].categories + + self.assertLessEqual( + len(categories), _MAX_RESULTS, 'Too many results returned.') + + def test_allow_list_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_allowlist=_ALLOW_LIST) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') + + def test_deny_list_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_denylist=_DENY_LIST) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertNotIn(label, _DENY_LIST, + f'Label {label} found but in deny list.') + + def test_combined_allowlist_and_denylist(self): + # Fails with combined allowlist and denylist + with self.assertRaisesRegex( + ValueError, + r'`category_allowlist` and `category_denylist` are mutually ' + r'exclusive options.'): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_allowlist=['foo'], + category_denylist=['bar']) + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_empty_classification_outputs(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), score_threshold=1) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + self.assertEmpty(image_result.classifications[0].entries[0].categories) + + def test_missing_result_callback(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + + def pass_through(unused_result: _ClassificationResult): + pass + + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + running_mode=running_mode, + result_callback=pass_through) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + # @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), + # (1, _ClassificationResult(classifications=[]))) + # def test_classify_async_calls(self, threshold, expected_result): + # observed_timestamp_ms = -1 + # + # def check_result(result: _ClassificationResult, timestamp_ms: int): + # self.assertEqual(result, expected_result) + # self.assertLess(observed_timestamp_ms, timestamp_ms) + # self.observed_timestamp_ms = timestamp_ms + # + # options = _ImageClassifierOptions( + # base_options=_BaseOptions(file_name=self.model_path), + # running_mode=_RUNNING_MODE.LIVE_STREAM, + # max_results=4, + # score_threshold=threshold, + # result_callback=check_result) + # classifier = _ImageClassifier.create_from_options(options) + # for timestamp in range(0, 300, 30): + # classifier.classify_async(self.test_image, timestamp) + # classifier.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 
7ff818610..7a27da179 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -36,3 +36,23 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "image_classification", + srcs = [ + "image_classification.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components:classifier_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py new file mode 100644 index 000000000..efe6aa11d --- /dev/null +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -0,0 +1,227 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe image classifier task.""" + +import dataclasses +from typing import Callable, List, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.tasks.cc.components import classifier_options_pb2 +from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions +_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_RunningMode = running_mode_module.VisionTaskRunningMode +_TaskInfo = task_info_module.TaskInfo +_TaskRunner = task_runner_module.TaskRunner + +_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' +_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph' + + +@dataclasses.dataclass +class ImageClassifierOptions: + """Options for the image classifier task. + + Attributes: + base_options: Base options for the image classifier task. + running_mode: The running mode of the task. Default to the image mode. 
+ Image classifier task has three running modes: + 1) The image mode for classifying objects on single image inputs. + 2) The video mode for classifying objects on the decoded frames of a + video. + 3) The live stream mode for classifying objects on a live stream of input + data, such as from camera. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, detection + results whose category name is not in this set will be filtered out. + Duplicate or unknown category names are ignored. Mutually exclusive with + `category_denylist`. + category_denylist: Denylist of category names. If non-empty, detection + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. 
+ """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + result_callback: Optional[ + Callable[[classifications_module.ClassificationResult], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ImageClassifierOptionsProto: + """Generates an ImageClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + + classifier_options_proto = _ClassifierOptionsProto( + display_names_locale=self.display_names_locale, + max_results=self.max_results, + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist) + + return _ImageClassifierOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto + ) + + +class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): + """Class that performs image classification on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'ImageClassifier': + """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. + + Note that the created `ImageClassifier` instance is in image mode, for + detecting objects on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `ImageClassifier` object that's created from the model file and the default + `ImageClassifierOptions`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. 
+ """ + base_options = _BaseOptions(file_name=model_path) + options = ImageClassifierOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: ImageClassifierOptions) -> 'ImageClassifier': + """Creates the `ImageClassifier` object from image classifier options. + + Args: + options: Options for the image classifier task. + + Returns: + `ImageClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from + `ImageClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + classification_result = classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + options.result_callback(classification_result) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + output_streams=[ + ':'.join([_CLASSIFICATION_RESULT_TAG, + _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + # TODO: Create an Image class for MediaPipe Tasks. + def classify( + self, + image: image_module.Image + ) -> classifications_module.ClassificationResult: + """Performs image classification on the provided MediaPipe Image. + + Args: + image: MediaPipe Image. + + Returns: + A classification result object that contains a list of classifications. 
+ + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. + """ + output_packets = self._process_image_data( + {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + return classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + """Sends live image data (an Image with a unique timestamp) to perform image + classification. + + This method will return immediately after the input image is accepted. The + results will be available via the `result_callback` provided in the + `ImageClassifierOptions`. The `detect_async` method is designed to process + live stream data such as camera input. To lower the overall latency, image + classifier may drop the input images if needed. In other words, it's not + guaranteed to have output per input image. The `result_callback` provides: + - A classification result object that contains a list of classifications. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + + Raises: + ValueError: If the current input timestamp is smaller than what the image + classifier has already processed. 
+ """ + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at(timestamp_ms) + }) From ec0c5f43412eb521f2d943ee6eef0423c95baf3c Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 11 Sep 2022 14:00:49 -0700 Subject: [PATCH 02/23] Code cleanup --- .../test/vision/image_classification_test.py | 25 ++------------ .../python/vision/image_classification.py | 33 +++---------------- 2 files changed, 6 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 3650c547c..a96eee6cb 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -124,7 +124,7 @@ class ImageClassifierTest(parameterized.TestCase): (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) def test_classify(self, model_file_type, max_results, - expected_classification_result): + expected_classification_result): # Creates classifier. 
if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -152,7 +152,7 @@ class ImageClassifierTest(parameterized.TestCase): (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) def test_classify_in_context(self, model_file_type, max_results, - expected_classification_result): + expected_classification_result): if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: @@ -275,27 +275,6 @@ class ImageClassifierTest(parameterized.TestCase): with _ImageClassifier.create_from_options(options) as unused_classifier: pass - # @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), - # (1, _ClassificationResult(classifications=[]))) - # def test_classify_async_calls(self, threshold, expected_result): - # observed_timestamp_ms = -1 - # - # def check_result(result: _ClassificationResult, timestamp_ms: int): - # self.assertEqual(result, expected_result) - # self.assertLess(observed_timestamp_ms, timestamp_ms) - # self.observed_timestamp_ms = timestamp_ms - # - # options = _ImageClassifierOptions( - # base_options=_BaseOptions(file_name=self.model_path), - # running_mode=_RUNNING_MODE.LIVE_STREAM, - # max_results=4, - # score_threshold=threshold, - # result_callback=check_result) - # classifier = _ImageClassifier.create_from_options(options) - # for timestamp in range(0, 300, 30): - # classifier.classify_async(self.test_image, timestamp) - # classifier.close() - if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index efe6aa11d..95381d78a 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -83,7 +83,8 @@ class ImageClassifierOptions: category_allowlist: 
Optional[List[str]] = None category_denylist: Optional[List[str]] = None result_callback: Optional[ - Callable[[classifications_module.ClassificationResult], None]] = None + Callable[[classifications_module.ClassificationResult], + None]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageClassifierOptionsProto: @@ -96,7 +97,8 @@ class ImageClassifierOptions: max_results=self.max_results, score_threshold=self.score_threshold, category_allowlist=self.category_allowlist, - category_denylist=self.category_denylist) + category_denylist=self.category_denylist + ) return _ImageClassifierOptionsProto( base_options=base_options_proto, @@ -198,30 +200,3 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) - - def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: - """Sends live image data (an Image with a unique timestamp) to perform image - classification. - - This method will return immediately after the input image is accepted. The - results will be available via the `result_callback` provided in the - `ImageClassifierOptions`. The `detect_async` method is designed to process - live stream data such as camera input. To lower the overall latency, image - classifier may drop the input images if needed. In other words, it's not - guaranteed to have output per input image. The `result_callback` provides: - - A classification result object that contains a list of classifications. - - The input image that the image classifier runs on. - - The input timestamp in milliseconds. - - Args: - image: MediaPipe Image. - timestamp_ms: The timestamp of the input image in milliseconds. - - Raises: - ValueError: If the current input timestamp is smaller than what the image - classifier has already processed. 
- """ - self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) - }) From 7287e5a0ed860a29abb8f1ad9c88d82021d1c8e1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 03:27:14 -0700 Subject: [PATCH 03/23] Added the image classifier task graph --- mediapipe/python/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 3a4a90b44..9fe21ab9e 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -86,6 +86,7 @@ cc_library( name = "builtin_task_graphs", deps = [ "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/image_classification:image_classifier_graph", ], ) From d8f7c5a43b311e166c337c451595348b5e1610d3 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 04:22:33 -0700 Subject: [PATCH 04/23] Moved ClassifierOptions to mediapipe/tasks/python/components to align with the C++ API --- mediapipe/tasks/python/components/BUILD | 28 ++++++ .../python/components/classifier_options.py | 92 +++++++++++++++++++ .../test/vision/image_classification_test.py | 31 +++++-- mediapipe/tasks/python/vision/BUILD | 2 +- .../python/vision/image_classification.py | 19 +--- 5 files changed, 146 insertions(+), 26 deletions(-) create mode 100644 mediapipe/tasks/python/components/BUILD create mode 100644 mediapipe/tasks/python/components/classifier_options.py diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD new file mode 100644 index 000000000..4094b7f7f --- /dev/null +++ b/mediapipe/tasks/python/components/BUILD @@ -0,0 +1,28 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "classifier_options", + srcs = ["classifier_options.py"], + deps = [ + "//mediapipe/tasks/cc/components:classifier_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/classifier_options.py b/mediapipe/tasks/python/components/classifier_options.py new file mode 100644 index 000000000..f6e61e48c --- /dev/null +++ b/mediapipe/tasks/python/components/classifier_options.py @@ -0,0 +1,92 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Classifier options data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components import classifier_options_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions + + +@dataclasses.dataclass +class ClassifierOptions: + """Options for classification processor. + + Attributes: + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, detection + results whose category name is not in this set will be filtered out. + Duplicate or unknown category names are ignored. Mutually exclusive with + `category_denylist`. + category_denylist: Denylist of category names. If non-empty, detection + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. 
+ """ + + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassifierOptionsProto: + """Generates a ClassifierOptions protobuf object.""" + return _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _ClassifierOptionsProto + ) -> 'ClassifierOptions': + """Creates a `ClassifierOptions` object from the given protobuf object.""" + return ClassifierOptions( + score_threshold=pb2_obj.score_threshold, + category_allowlist=[ + str(name) for name in pb2_obj.class_name_allowlist + ], + category_denylist=[ + str(name) for name in pb2_obj.class_name_denylist + ], + display_names_locale=pb2_obj.display_names_locale, + max_results=pb2_obj.max_results) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. 
+ """ + if not isinstance(other, ClassifierOptions): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index a96eee6cb..51dcb1adf 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -19,6 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components import classifier_options from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module @@ -27,6 +28,7 @@ from mediapipe.tasks.python.vision import image_classification from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classifications_module.Classifications @@ -136,8 +138,9 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') + classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, classifier_options=classifier_options) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. 
@@ -163,8 +166,9 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') + classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs object detection on the input. image_result = classifier.classify(self.test_image) @@ -172,9 +176,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertEqual(image_result, expected_classification_result) def test_score_threshold_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(file_name=self.model_path), - score_threshold=_SCORE_THRESHOLD) + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -189,9 +194,10 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(file_name=self.model_path), - max_results=_MAX_RESULTS) + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -201,9 +207,10 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): + classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(file_name=self.model_path), - category_allowlist=_ALLOW_LIST) + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -216,9 +223,10 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): + classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - category_denylist=_DENY_LIST) + base_options=_BaseOptions(file_name=self.model_path), + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -236,16 +244,19 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): + classifier_options = _ClassifierOptions(category_allowlist=['foo'], + category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(file_name=self.model_path), - category_allowlist=['foo'], - category_denylist=['bar']) + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): + classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), score_threshold=1) + base_options=_BaseOptions(file_name=self.model_path), + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7a27da179..40caf129f 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,8 +46,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index 95381d78a..94176cdf8 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.components import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 +from mediapipe.tasks.python.components import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,8 +31,8 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions 
= base_options_module.BaseOptions -_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _TaskRunner = task_runner_module.TaskRunner @@ -77,11 +77,7 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - display_names_locale: Optional[str] = None - max_results: Optional[int] = None - score_threshold: Optional[float] = None - category_allowlist: Optional[List[str]] = None - category_denylist: Optional[List[str]] = None + classifier_options: _ClassifierOptions = _ClassifierOptions() result_callback: Optional[ Callable[[classifications_module.ClassificationResult], None]] = None @@ -91,14 +87,7 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - - classifier_options_proto = _ClassifierOptionsProto( - display_names_locale=self.display_names_locale, - max_results=self.max_results, - score_threshold=self.score_threshold, - category_allowlist=self.category_allowlist, - category_denylist=self.category_denylist - ) + classifier_options_proto = self.classifier_options.to_pb2() return _ImageClassifierOptionsProto( base_options=base_options_proto, From bb750befd2e9ca066b76981404f09bbd91919a18 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 04:24:11 -0700 Subject: [PATCH 05/23] Updated ImageClassifierOptions docstring --- .../tasks/python/vision/image_classification.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index 
94176cdf8..f08da9f2b 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -57,20 +57,7 @@ class ImageClassifierOptions: video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - display_names_locale: The locale to use for display names specified through - the TFLite Model Metadata. - max_results: The maximum number of top-scored classification results to - return. - score_threshold: Overrides the ones provided in the model metadata. Results - below this value are rejected. - category_allowlist: Allowlist of category names. If non-empty, detection - results whose category name is not in this set will be filtered out. - Duplicate or unknown category names are ignored. Mutually exclusive with - `category_denylist`. - category_denylist: Denylist of category names. If non-empty, detection - results whose category name is in this set will be filtered out. Duplicate - or unknown category names are ignored. Mutually exclusive with - `category_allowlist`. + classifier_options: Options for the image classification task. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
From 72319ecbf57cb56502e8e9af912b7caa37e2b2eb Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:14:42 -0700 Subject: [PATCH 06/23] Updated BUILD --- mediapipe/tasks/python/test/vision/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index a63c36b55..46efc402d 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -26,7 +26,7 @@ py_test( "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ - # build rule placeholder: numpy dep, + "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:bounding_box", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:detections", @@ -34,7 +34,6 @@ py_test( "//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/vision:object_detector", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", - "@absl_py//absl/testing:parameterized", ], ) @@ -46,6 +45,7 @@ py_test( "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ + "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", From 8ad591822922338aca7cabd25fddd3bacdc98d9d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:18:10 -0700 Subject: [PATCH 07/23] Removed some tests --- .../test/vision/image_classification_test.py | 165 ------------------ 1 file changed, 165 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 51dcb1adf..0d7f9e7f0 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -92,36 +92,6 @@ class 
ImageClassifierTest(parameterized.TestCase): test_util.get_test_data_path(_IMAGE_FILE)) self.model_path = test_util.get_test_data_path(_MODEL_FILE) - def test_create_from_file_succeeds_with_valid_model_path(self): - # Creates with default option and valid model file successfully. - with _ImageClassifier.create_from_model_path(self.model_path) as classifier: - self.assertIsInstance(classifier, _ImageClassifier) - - def test_create_from_options_succeeds_with_valid_model_path(self): - # Creates with options containing model file successfully. - base_options = _BaseOptions(file_name=self.model_path) - options = _ImageClassifierOptions(base_options=base_options) - with _ImageClassifier.create_from_options(options) as classifier: - self.assertIsInstance(classifier, _ImageClassifier) - - def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. - with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name' or 'file_descriptor_meta'."): - base_options = _BaseOptions(file_name='') - options = _ImageClassifierOptions(base_options=base_options) - _ImageClassifier.create_from_options(options) - - def test_create_from_options_succeeds_with_valid_model_content(self): - # Creates with options containing model content successfully. - with open(self.model_path, 'rb') as f: - base_options = _BaseOptions(file_content=f.read()) - options = _ImageClassifierOptions(base_options=base_options) - classifier = _ImageClassifier.create_from_options(options) - self.assertIsInstance(classifier, _ImageClassifier) - @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) @@ -151,141 +121,6 @@ class ImageClassifierTest(parameterized.TestCase): # a context. 
classifier.close() - @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) - def test_classify_in_context(self, model_file_type, max_results, - expected_classification_result): - if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(file_name=self.model_path) - elif model_file_type is ModelFileType.FILE_CONTENT: - with open(self.model_path, 'rb') as f: - model_content = f.read() - base_options = _BaseOptions(file_content=model_content) - else: - # Should never happen - raise ValueError('model_file_type is invalid.') - - classifier_options = _ClassifierOptions(max_results=max_results) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs object detection on the input. - image_result = classifier.classify(self.test_image) - # Comparing results. - self.assertEqual(image_result, expected_classification_result) - - def test_score_threshold_option(self): - classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. - image_result = classifier.classify(self.test_image) - classifications = image_result.classifications - - for classification in classifications: - for entry in classification.entries: - score = entry.categories[0].score - self.assertGreaterEqual( - score, _SCORE_THRESHOLD, - f'Classification with score lower than threshold found. 
' - f'{classification}') - - def test_max_results_option(self): - classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. - image_result = classifier.classify(self.test_image) - categories = image_result.classifications[0].entries[0].categories - - self.assertLessEqual( - len(categories), _MAX_RESULTS, 'Too many results returned.') - - def test_allow_list_option(self): - classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. - image_result = classifier.classify(self.test_image) - classifications = image_result.classifications - - for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name - self.assertIn(label, _ALLOW_LIST, - f'Label {label} found but not in label allow list') - - def test_deny_list_option(self): - classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. 
- image_result = classifier.classify(self.test_image) - classifications = image_result.classifications - - for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name - self.assertNotIn(label, _DENY_LIST, - f'Label {label} found but in deny list.') - - def test_combined_allowlist_and_denylist(self): - # Fails with combined allowlist and denylist - with self.assertRaisesRegex( - ValueError, - r'`category_allowlist` and `category_denylist` are mutually ' - r'exclusive options.'): - classifier_options = _ClassifierOptions(category_allowlist=['foo'], - category_denylist=['bar']) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - - def test_empty_classification_outputs(self): - classifier_options = _ClassifierOptions(score_threshold=1) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. 
- image_result = classifier.classify(self.test_image) - self.assertEmpty(image_result.classifications[0].entries[0].categories) - - def test_missing_result_callback(self): - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - - @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) - def test_illegal_result_callback(self, running_mode): - - def pass_through(unused_result: _ClassificationResult): - pass - - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - running_mode=running_mode, - result_callback=pass_through) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - if __name__ == '__main__': absltest.main() From b04af0cafa7b8ec0d48122d20c7bd2f7693b12f0 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:45:20 -0700 Subject: [PATCH 08/23] Updated implementation and tests --- mediapipe/python/BUILD | 2 +- mediapipe/tasks/python/components/proto/BUILD | 29 +++++++++++++++++++ .../{ => proto}/classifier_options.py | 0 .../test/vision/image_classification_test.py | 4 +-- mediapipe/tasks/python/vision/BUILD | 4 +-- .../python/vision/image_classification.py | 6 ++-- 6 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 mediapipe/tasks/python/components/proto/BUILD rename mediapipe/tasks/python/components/{ => proto}/classifier_options.py (100%) diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 9fe21ab9e..33667d18e 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -86,7 +86,7 @@ cc_library( name = "builtin_task_graphs", deps = [ 
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", - "//mediapipe/tasks/cc/vision/image_classification:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", ], ) diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/proto/BUILD new file mode 100644 index 000000000..7814ec675 --- /dev/null +++ b/mediapipe/tasks/python/components/proto/BUILD @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "classifier_options", + srcs = ["classifier_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + diff --git a/mediapipe/tasks/python/components/classifier_options.py b/mediapipe/tasks/python/components/proto/classifier_options.py similarity index 100% rename from mediapipe/tasks/python/components/classifier_options.py rename to mediapipe/tasks/python/components/proto/classifier_options.py diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 0d7f9e7f0..08d850dab 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -99,11 +99,11 @@ class ImageClassifierTest(parameterized.TestCase): expected_classification_result): # Creates classifier. 
if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(file_name=self.model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: with open(self.model_path, 'rb') as f: model_content = f.read() - base_options = _BaseOptions(file_content=model_content) + base_options = _BaseOptions(model_asset_buffer=model_content) else: # Should never happen raise ValueError('model_file_type is invalid.') diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 40caf129f..5e5bfe3ba 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,8 +46,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", - "//mediapipe/tasks/python/components:classifier_options", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components/proto:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index f08da9f2b..c240ffedf 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 -from mediapipe.tasks.python.components import 
classifier_options +from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2 +from mediapipe.tasks.python.components.proto import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -104,7 +104,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): file such as invalid file path. RuntimeError: If other types of error occurred. """ - base_options = _BaseOptions(file_name=model_path) + base_options = _BaseOptions(model_asset_path=model_path) options = ImageClassifierOptions( base_options=base_options, running_mode=_RunningMode.IMAGE) return cls.create_from_options(options) From cba2a6035c55842c440cb8be78244794eca38cf4 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 03:15:17 -0700 Subject: [PATCH 09/23] Code cleanup --- .../tasks/python/components/proto/classifier_options.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 7 ++++--- ...ge_classification_test.py => image_classifier_test.py} | 8 ++++---- mediapipe/tasks/python/vision/BUILD | 4 ++-- .../{image_classification.py => image_classifier.py} | 0 5 files changed, 11 insertions(+), 10 deletions(-) rename mediapipe/tasks/python/test/vision/{image_classification_test.py => image_classifier_test.py} (94%) rename mediapipe/tasks/python/vision/{image_classification.py => image_classifier.py} (100%) diff --git a/mediapipe/tasks/python/components/proto/classifier_options.py b/mediapipe/tasks/python/components/proto/classifier_options.py index f6e61e48c..c73828c13 100644 --- a/mediapipe/tasks/python/components/proto/classifier_options.py +++ b/mediapipe/tasks/python/components/proto/classifier_options.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components import classifier_options_pb2 +from 
mediapipe.tasks.cc.components.proto import classifier_options_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index c96be63f8..6b6b9e3e2 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -38,19 +38,20 @@ py_test( ) py_test( - name = "image_classification_test", - srcs = ["image_classification_test.py"], + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], data = [ "//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/proto:classifier_options", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_util", - "//mediapipe/tasks/python/vision:image_classification", + "//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py similarity index 94% rename from mediapipe/tasks/python/test/vision/image_classification_test.py rename to mediapipe/tasks/python/test/vision/image_classifier_test.py index 08d850dab..3d7d116df 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,12 +19,12 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components import classifier_options +from 
mediapipe.tasks.python.components.proto import classifier_options from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_util -from mediapipe.tasks.python.vision import image_classification +from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions @@ -34,8 +34,8 @@ _ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classifications_module.Classifications _ClassificationResult = classifications_module.ClassificationResult _Image = image_module.Image -_ImageClassifier = image_classification.ImageClassifier -_ImageClassifierOptions = image_classification.ImageClassifierOptions +_ImageClassifier = image_classifier.ImageClassifier +_ImageClassifierOptions = image_classifier.ImageClassifierOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 5e5bfe3ba..a724f149a 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -38,9 +38,9 @@ py_library( ) py_library( - name = "image_classification", + name = "image_classifier", srcs = [ - "image_classification.py", + "image_classifier.py", ], deps = [ "//mediapipe/python:_framework_bindings", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classifier.py similarity index 100% rename from mediapipe/tasks/python/vision/image_classification.py rename to mediapipe/tasks/python/vision/image_classifier.py From 68fea17e30a5c316f30617e1bda2e53336b41108 Mon Sep 17 00:00:00 2001 
From: kinaryml Date: Thu, 29 Sep 2022 13:34:10 -0700 Subject: [PATCH 10/23] Removed unused dependencies in BUILD --- mediapipe/tasks/python/components/BUILD | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD index 4094b7f7f..00fd4061c 100644 --- a/mediapipe/tasks/python/components/BUILD +++ b/mediapipe/tasks/python/components/BUILD @@ -17,12 +17,3 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) - -py_library( - name = "classifier_options", - srcs = ["classifier_options.py"], - deps = [ - "//mediapipe/tasks/cc/components:classifier_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) From 85af7ac9bcbf024178d8f389c8251aed2c5da36d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 13:36:07 -0700 Subject: [PATCH 11/23] Removed unused BUILD --- mediapipe/tasks/python/components/BUILD | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 mediapipe/tasks/python/components/BUILD diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD deleted file mode 100644 index 00fd4061c..000000000 --- a/mediapipe/tasks/python/components/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Placeholder for internal Python strict library compatibility macro. 
- -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) From f241630b56261ac5009ba6638c7f5b8db9e699e1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 3 Oct 2022 15:11:48 -0700 Subject: [PATCH 12/23] Revised implementation to align with recent changes --- .../components/containers/{ => proto}/BUILD | 2 +- .../containers/{ => proto}/__init__.py | 0 .../containers/{ => proto}/bounding_box.py | 0 .../containers/{ => proto}/category.py | 0 .../containers/{ => proto}/classifications.py | 4 ++-- .../containers/{ => proto}/detections.py | 0 .../components/{ => processors}/proto/BUILD | 3 +-- .../proto/classifier_options.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 8 ++++---- .../test/vision/image_classifier_test.py | 14 ++++++------- mediapipe/tasks/python/vision/BUILD | 6 +++--- .../tasks/python/vision/image_classifier.py | 20 ++++++++++--------- 12 files changed, 30 insertions(+), 29 deletions(-) rename mediapipe/tasks/python/components/containers/{ => proto}/BUILD (95%) rename mediapipe/tasks/python/components/containers/{ => proto}/__init__.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/bounding_box.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/category.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/classifications.py (96%) rename mediapipe/tasks/python/components/containers/{ => proto}/detections.py (100%) rename mediapipe/tasks/python/components/{ => processors}/proto/BUILD (91%) rename mediapipe/tasks/python/components/{ => processors}/proto/classifier_options.py (97%) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/proto/BUILD similarity index 95% rename from mediapipe/tasks/python/components/containers/BUILD rename to mediapipe/tasks/python/components/containers/proto/BUILD index 450111161..d46df22da 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ 
b/mediapipe/tasks/python/components/containers/proto/BUILD @@ -53,7 +53,7 @@ py_library( srcs = ["classifications.py"], deps = [ ":category", - "//mediapipe/tasks/cc/components/containers:classifications_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/__init__.py b/mediapipe/tasks/python/components/containers/proto/__init__.py similarity index 100% rename from mediapipe/tasks/python/components/containers/__init__.py rename to mediapipe/tasks/python/components/containers/proto/__init__.py diff --git a/mediapipe/tasks/python/components/containers/bounding_box.py b/mediapipe/tasks/python/components/containers/proto/bounding_box.py similarity index 100% rename from mediapipe/tasks/python/components/containers/bounding_box.py rename to mediapipe/tasks/python/components/containers/proto/bounding_box.py diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/proto/category.py similarity index 100% rename from mediapipe/tasks/python/components/containers/category.py rename to mediapipe/tasks/python/components/containers/proto/category.py diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/proto/classifications.py similarity index 96% rename from mediapipe/tasks/python/components/containers/classifications.py rename to mediapipe/tasks/python/components/containers/proto/classifications.py index 19c5decde..2c43370b2 100644 --- a/mediapipe/tasks/python/components/containers/classifications.py +++ b/mediapipe/tasks/python/components/containers/proto/classifications.py @@ -16,8 +16,8 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components.containers import classifications_pb2 -from mediapipe.tasks.python.components.containers import category as 
category_module +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.python.components.containers.proto import category as category_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationEntryProto = classifications_pb2.ClassificationEntry diff --git a/mediapipe/tasks/python/components/containers/detections.py b/mediapipe/tasks/python/components/containers/proto/detections.py similarity index 100% rename from mediapipe/tasks/python/components/containers/detections.py rename to mediapipe/tasks/python/components/containers/proto/detections.py diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/processors/proto/BUILD similarity index 91% rename from mediapipe/tasks/python/components/proto/BUILD rename to mediapipe/tasks/python/components/processors/proto/BUILD index 7814ec675..814e15d1f 100644 --- a/mediapipe/tasks/python/components/proto/BUILD +++ b/mediapipe/tasks/python/components/processors/proto/BUILD @@ -22,8 +22,7 @@ py_library( name = "classifier_options", srcs = ["classifier_options.py"], deps = [ - "//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) - diff --git a/mediapipe/tasks/python/components/proto/classifier_options.py b/mediapipe/tasks/python/components/processors/proto/classifier_options.py similarity index 97% rename from mediapipe/tasks/python/components/proto/classifier_options.py rename to mediapipe/tasks/python/components/processors/proto/classifier_options.py index c73828c13..b4597e57a 100644 --- a/mediapipe/tasks/python/components/proto/classifier_options.py +++ b/mediapipe/tasks/python/components/processors/proto/classifier_options.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components.proto import 
classifier_options_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index df2e72f98..107a78a33 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -46,11 +46,11 @@ py_test( ], deps = [ "//mediapipe/python:_framework_bindings", - "//mediapipe/tasks/python/components/proto:classifier_options", - "//mediapipe/tasks/python/components/containers:category", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors/proto:classifier_options", + "//mediapipe/tasks/python/components/containers/proto:category", + "//mediapipe/tasks/python/components/containers/proto:classifications", "//mediapipe/tasks/python/core:base_options", - "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 3d7d116df..c16587cb5 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,11 +19,11 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.proto import classifier_options -from mediapipe.tasks.python.components.containers import category as category_module -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from 
mediapipe.tasks.python.components.processors.proto import classifier_options +from mediapipe.tasks.python.components.containers.proto import category as category_module +from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module @@ -88,9 +88,9 @@ class ImageClassifierTest(parameterized.TestCase): def setUp(self): super().setUp() - self.test_image = test_util.read_test_image( - test_util.get_test_data_path(_IMAGE_FILE)) - self.model_path = test_util.get_test_data_path(_MODEL_FILE) + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_utils.get_test_data_path(_MODEL_FILE) @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index a724f149a..69fd8f2a6 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,9 +46,9 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2", - "//mediapipe/tasks/python/components/proto:classifier_options", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/processors/proto:classifier_options", + "//mediapipe/tasks/python/components/containers/proto:classifications", "//mediapipe/tasks/python/core:base_options", 
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index c240ffedf..809429c37 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2 -from mediapipe.tasks.python.components.proto import classifier_options -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.processors.proto import classifier_options +from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions -_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -41,7 +41,7 @@ 
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_TAG = 'IMAGE' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' @dataclasses.dataclass @@ -70,13 +70,13 @@ class ImageClassifierOptions: None]] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ImageClassifierOptionsProto: + def to_pb2(self) -> _ImageClassifierGraphOptionsProto: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True classifier_options_proto = self.classifier_options.to_pb2() - return _ImageClassifierOptionsProto( + return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, classifier_options=classifier_options_proto ) @@ -138,7 +138,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, - input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -153,7 +155,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): # TODO: Create an Image class for MediaPipe Tasks. def classify( self, - image: image_module.Image + image: image_module.Image, ) -> classifications_module.ClassificationResult: """Performs image classification on the provided MediaPipe Image. 
From 64d5c159c6f509b656b8e28ed91375c24db3e35b Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 3 Oct 2022 16:14:35 -0700 Subject: [PATCH 13/23] Fixed an auto formatting issue that caused classification_posprocessing_graph's registration to fail --- .../processors/classification_postprocessing_graph.cc | 4 ++-- .../tasks/python/test/vision/image_classifier_test.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 35adab687..649ff2c11 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -507,8 +507,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: - ClassificationPostprocessingGraph); // NOLINT +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT } // namespace processors } // namespace components diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index c16587cb5..0ce70395e 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -48,22 +48,22 @@ _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( categories=[ _Category( index=934, - score=0.7952049970626831, + score=0.7939587831497192, display_name='', category_name='cheeseburger'), _Category( index=932, - score=0.02732999622821808, + score=0.02739289402961731, display_name='', category_name='bagel'), _Category( index=925, - score=0.01933487318456173, + score=0.01934075355529785, display_name='', category_name='guacamole'), _Category( index=963, - 
score=0.006279350258409977, + score=0.006327860057353973, display_name='', category_name='meat loaf') ], From a22a5283d272771323186af02ce061c74a32c501 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 5 Oct 2022 04:43:30 -0700 Subject: [PATCH 14/23] Adjusted namespaces --- .../tasks/python/components/containers/{proto => }/BUILD | 0 .../python/components/containers/{proto => }/__init__.py | 0 .../components/containers/{proto => }/bounding_box.py | 0 .../python/components/containers/{proto => }/category.py | 0 .../components/containers/{proto => }/classifications.py | 2 +- .../python/components/containers/{proto => }/detections.py | 0 .../tasks/python/components/processors/{proto => }/BUILD | 0 .../components/processors/{proto => }/classifier_options.py | 0 mediapipe/tasks/python/test/vision/BUILD | 6 +++--- mediapipe/tasks/python/test/vision/image_classifier_test.py | 6 +++--- mediapipe/tasks/python/vision/BUILD | 4 ++-- mediapipe/tasks/python/vision/image_classifier.py | 4 ++-- 12 files changed, 11 insertions(+), 11 deletions(-) rename mediapipe/tasks/python/components/containers/{proto => }/BUILD (100%) rename mediapipe/tasks/python/components/containers/{proto => }/__init__.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/bounding_box.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/category.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/classifications.py (98%) rename mediapipe/tasks/python/components/containers/{proto => }/detections.py (100%) rename mediapipe/tasks/python/components/processors/{proto => }/BUILD (100%) rename mediapipe/tasks/python/components/processors/{proto => }/classifier_options.py (100%) diff --git a/mediapipe/tasks/python/components/containers/proto/BUILD b/mediapipe/tasks/python/components/containers/BUILD similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/BUILD rename to mediapipe/tasks/python/components/containers/BUILD 
diff --git a/mediapipe/tasks/python/components/containers/proto/__init__.py b/mediapipe/tasks/python/components/containers/__init__.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/__init__.py rename to mediapipe/tasks/python/components/containers/__init__.py diff --git a/mediapipe/tasks/python/components/containers/proto/bounding_box.py b/mediapipe/tasks/python/components/containers/bounding_box.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/bounding_box.py rename to mediapipe/tasks/python/components/containers/bounding_box.py diff --git a/mediapipe/tasks/python/components/containers/proto/category.py b/mediapipe/tasks/python/components/containers/category.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/category.py rename to mediapipe/tasks/python/components/containers/category.py diff --git a/mediapipe/tasks/python/components/containers/proto/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py similarity index 98% rename from mediapipe/tasks/python/components/containers/proto/classifications.py rename to mediapipe/tasks/python/components/containers/classifications.py index 2c43370b2..c51816e07 100644 --- a/mediapipe/tasks/python/components/containers/proto/classifications.py +++ b/mediapipe/tasks/python/components/containers/classifications.py @@ -17,7 +17,7 @@ import dataclasses from typing import Any, List, Optional from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 -from mediapipe.tasks.python.components.containers.proto import category as category_module +from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationEntryProto = classifications_pb2.ClassificationEntry diff --git a/mediapipe/tasks/python/components/containers/proto/detections.py 
b/mediapipe/tasks/python/components/containers/detections.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/detections.py rename to mediapipe/tasks/python/components/containers/detections.py diff --git a/mediapipe/tasks/python/components/processors/proto/BUILD b/mediapipe/tasks/python/components/processors/BUILD similarity index 100% rename from mediapipe/tasks/python/components/processors/proto/BUILD rename to mediapipe/tasks/python/components/processors/BUILD diff --git a/mediapipe/tasks/python/components/processors/proto/classifier_options.py b/mediapipe/tasks/python/components/processors/classifier_options.py similarity index 100% rename from mediapipe/tasks/python/components/processors/proto/classifier_options.py rename to mediapipe/tasks/python/components/processors/classifier_options.py diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 107a78a33..09bbf1958 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -46,9 +46,9 @@ py_test( ], deps = [ "//mediapipe/python:_framework_bindings", - "//mediapipe/tasks/python/components/processors/proto:classifier_options", - "//mediapipe/tasks/python/components/containers/proto:category", - "//mediapipe/tasks/python/components/containers/proto:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 0ce70395e..7fbd96ddc 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ 
b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,9 +19,9 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.processors.proto import classifier_options -from mediapipe.tasks.python.components.containers.proto import category as category_module -from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 69fd8f2a6..7d44e3326 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -47,8 +47,8 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", - "//mediapipe/tasks/python/components/processors/proto:classifier_options", - "//mediapipe/tasks/python/components/containers/proto:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 809429c37..94dcd4d70 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ 
b/mediapipe/tasks/python/vision/image_classifier.py @@ -22,8 +22,8 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 -from mediapipe.tasks.python.components.processors.proto import classifier_options -from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls From e250c903f5b6e8c87b9a4cef175137df95802804 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 5 Oct 2022 05:24:52 -0700 Subject: [PATCH 15/23] Added the ClassifyForVideo API --- .../test/vision/image_classifier_test.py | 13 +++++ .../tasks/python/vision/image_classifier.py | 52 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 7fbd96ddc..5143a28db 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -15,6 +15,7 @@ import enum +import numpy as np from absl.testing import absltest from absl.testing import parameterized @@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase): # a context. 
classifier.close() + def test_classify_for_video(self): + classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + self.test_image, timestamp) + self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 94dcd4d70..b3bafa113 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,7 +14,7 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, List, Mapping, Optional +from typing import Callable, Mapping, Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' @@ -66,8 +67,8 @@ class ImageClassifierOptions: running_mode: _RunningMode = _RunningMode.IMAGE classifier_options: _ClassifierOptions = _ClassifierOptions() result_callback: Optional[ - Callable[[classifications_module.ClassificationResult], - None]] = None + Callable[[classifications_module.ClassificationResult, image_module.Image, + int], None]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageClassifierGraphOptionsProto: @@ -134,7 +135,9 @@ class 
ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) - options.result_callback(classification_result) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp + options.result_callback(classification_result, image, timestamp) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, - _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + _CLASSIFICATION_RESULT_OUT_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ], task_options=options) return cls( @@ -175,6 +179,40 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) return classifications_module.ClassificationResult([ - classifications_module.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_for_video( + self, image: image_module.Image, + timestamp_ms: int + ) -> classifications_module.ClassificationResult: + """Performs image classification on the provided video frames. + + Only use this method when the ImageClassifier is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + + Returns: + A classification result object that contains a list of classifications. 
+ + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. + """ + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at(timestamp_ms) + }) + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + return classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications ]) From cb806071ba30ba66cfa1ec26d80b9b5c0b3e2501 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 7 Oct 2022 22:26:49 -0700 Subject: [PATCH 16/23] Added more tests and updated the APIs to use a new constant --- .../test/vision/image_classifier_test.py | 338 ++++++++++++++++-- .../tasks/python/vision/image_classifier.py | 45 ++- 2 files changed, 355 insertions(+), 28 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 5143a28db..073674c3f 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -14,6 +14,7 @@ """Tests for image classifier.""" import enum +from unittest import mock import numpy as np from absl.testing import absltest @@ -41,33 +42,46 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' +_EXPECTED_CATEGORIES = [ + _Category( + index=934, + score=0.7939587831497192, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02739289402961731, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01934075355529785, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006327860057353973, + display_name='', + category_name='meat loaf') +] 
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( classifications=[ _Classifications( entries=[ _ClassificationEntry( - categories=[ - _Category( - index=934, - score=0.7939587831497192, - display_name='', - category_name='cheeseburger'), - _Category( - index=932, - score=0.02739289402961731, - display_name='', - category_name='bagel'), - _Category( - index=925, - score=0.01934075355529785, - display_name='', - category_name='guacamole'), - _Category( - index=963, - score=0.006327860057353973, - display_name='', - category_name='meat loaf') - ], + categories=_EXPECTED_CATEGORIES, + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[], timestamp_ms=0 ) ], @@ -93,6 +107,36 @@ class ImageClassifierTest(parameterized.TestCase): test_utils.get_test_data_path(_IMAGE_FILE)) self.model_path = test_utils.get_test_data_path(_MODEL_FILE) + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _ImageClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageClassifierOptions(base_options=base_options) + with _ImageClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. 
+ with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _ImageClassifierOptions(base_options=base_options) + _ImageClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _ImageClassifierOptions(base_options=base_options) + classifier = _ImageClassifier.create_from_options(options) + self.assertIsInstance(classifier, _ImageClassifier) + @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) @@ -122,6 +166,183 @@ class ImageClassifierTest(parameterized.TestCase): # a context. classifier.close() + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + def test_classify_in_context(self, model_file_type, max_results, + expected_classification_result): + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + classifier_options = _ClassifierOptions(max_results=max_results) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + # Comparing results. + self.assertEqual(image_result, expected_classification_result) + + def test_score_threshold_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + score = entry.categories[0].score + self.assertGreaterEqual( + score, _SCORE_THRESHOLD, + f'Classification with score lower than threshold found. ' + f'{classification}') + + def test_max_results_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + categories = image_result.classifications[0].entries[0].categories + + self.assertLessEqual( + len(categories), _MAX_RESULTS, 'Too many results returned.') + + def test_allow_list_option(self): + classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') + + def test_deny_list_option(self): + classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertNotIn(label, _DENY_LIST, + f'Label {label} found but in deny list.') + + def test_combined_allowlist_and_denylist(self): + # Fails with combined allowlist and denylist + with self.assertRaisesRegex( + ValueError, + r'`category_allowlist` and `category_denylist` are mutually ' + r'exclusive options.'): + classifier_options = _ClassifierOptions(category_allowlist=['foo'], + category_denylist=['bar']) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_empty_classification_outputs(self): + classifier_options = _ClassifierOptions(score_threshold=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + self.assertEmpty(image_result.classifications[0].entries[0].categories) + + def test_missing_result_callback(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_calling_classify_for_video_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_calling_classify_async_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_calling_classify_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as 
classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_async_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_classify_for_video_with_out_of_order_timestamp(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + unused_result = classifier.classify_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_for_video(self.test_image, 0) + def test_classify_for_video(self): classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( @@ -132,7 +353,78 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT) + expected_classification_result = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp) + ], + head_index=0, head_name='probability') + ]) + self.assertEqual(classification_result, expected_classification_result) + + def test_calling_classify_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with 
_ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_for_video_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_classify_async_calls_with_illegal_timestamp(self): + classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + classifier.classify_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_async(self.test_image, 0) + + # TODO: Fix the packet is empty issue. 
+ """ + @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), + (1, _EMPTY_CLASSIFICATION_RESULT)) + def test_classify_async_calls(self, threshold, expected_result): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, output_image: _Image, + timestamp_ms: int): + self.assertEqual(result, expected_result) + self.assertTrue( + np.array_equal(output_image.numpy_view(), + self.test_image.numpy_view())) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + classifier_options = _ClassifierOptions( + max_results=4, score_threshold=threshold) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=check_result) + classifier = _ImageClassifier.create_from_options(options) + for timestamp in range(0, 300, 30): + classifier.classify_async(self.test_image, timestamp) + classifier.close() + """ if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index b3bafa113..36c5561c4 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -43,6 +43,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 @dataclasses.dataclass @@ -91,7 +92,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. Note that the created `ImageClassifier` instance is in image mode, for - detecting objects on single image inputs. + classifying objects on single image inputs. Args: model_path: Path to the model. 
@@ -137,7 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp - options.result_callback(classification_result, image, timestamp) + options.result_callback(classification_result, image, + timestamp.value / _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -156,7 +158,6 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): _RunningMode.LIVE_STREAM), options.running_mode, packets_callback if options.result_callback else None) - # TODO: Create an Image class for MediaPipe Tasks. def classify( self, image: image_module.Image, @@ -206,8 +207,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If image classification failed to run. """ output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -216,3 +218,36 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) + + def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + """Sends live image data (an Image with a unique timestamp) to perform + image classification. + + Only use this method when the ImageClassifier is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `ImageClassifierOptions`. 
The + `classify_async` method is designed to process live stream data such as + camera input. To lower the overall latency, image classifier may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - A classification result object that contains a list of classifications. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + + Raises: + ValueError: If the current input timestamp is smaller than what the image + classifier has already processed. + """ + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) From 44e6f8e1a1edea42bc92c1e9c83d59ddce678fb8 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 08:15:40 -0700 Subject: [PATCH 17/23] Updated image classifier to use a region of interest parameter --- .../tasks/python/components/containers/BUILD | 9 ++ .../python/components/containers/rect.py | 136 ++++++++++++++++ mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/image_classifier_test.py | 146 ++++++++++++------ mediapipe/tasks/python/vision/BUILD | 1 + .../tasks/python/vision/image_classifier.py | 46 ++++-- 6 files changed, 281 insertions(+), 58 deletions(-) create mode 100644 mediapipe/tasks/python/components/containers/rect.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d46df22da..723210f5f 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -27,6 +27,15 @@ py_library( ], ) +py_library( + name = "rect", + srcs = ["rect.py"], + deps = [ + "//mediapipe/framework/formats:rect_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( 
name = "category", srcs = ["category.py"], diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py new file mode 100644 index 000000000..e74be1b0e --- /dev/null +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -0,0 +1,136 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rect data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.framework.formats import rect_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_RectProto = rect_pb2.Rect +_NormalizedRectProto = rect_pb2.NormalizedRect + + +@dataclasses.dataclass +class Rect: + """A rectangle with rotation in image coordinates. + + Attributes: + x_center: The X coordinate of the rectangle's center, in pixels. + y_center: The Y coordinate of the rectangle's center, in pixels. + width: The width of the rectangle, in pixels. + height: The height of the rectangle, in pixels. + rotation: Rotation angle is clockwise in radians. + rect_id: Optional unique id to help associate different rectangles to each + other. 
+ """ + + x_center: int + y_center: int + width: int + height: int + rotation: Optional[float] = 0.0 + rect_id: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _RectProto: + """Generates a Rect protobuf object.""" + return _RectProto( + x_center=self.x_center, + y_center=self.y_center, + width=self.width, + height=self.height, + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect': + """Creates a `Rect` object from the given protobuf object.""" + return Rect( + x_center=pb2_obj.x_center, + y_center=pb2_obj.y_center, + width=pb2_obj.width, + height=pb2_obj.height) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Rect): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class NormalizedRect: + """A rectangle with rotation in normalized coordinates. The values of box + center location and size are within [0, 1]. + + Attributes: + x_center: The normalized X coordinate of the rectangle's center. + y_center: The normalized Y coordinate of the rectangle's center. + width: The width of the rectangle. + height: The height of the rectangle. + rotation: Rotation angle is clockwise in radians. + rect_id: Optional unique id to help associate different rectangles to each + other. 
+ """ + + x_center: float + y_center: float + width: float + height: float + rotation: Optional[float] = 0.0 + rect_id: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _NormalizedRectProto: + """Generates a NormalizedRect protobuf object.""" + return _NormalizedRectProto( + x_center=self.x_center, + y_center=self.y_center, + width=self.width, + height=self.height, + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect': + """Creates a `NormalizedRect` object from the given protobuf object.""" + return NormalizedRect( + x_center=pb2_obj.x_center, + y_center=pb2_obj.y_center, + width=pb2_obj.width, + height=pb2_obj.height) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedRect): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 09bbf1958..e4d331784 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,6 +49,7 @@ py_test( "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 073674c3f..b9098b55b 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py 
@@ -24,11 +24,13 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_NormalizedRect = rect_module.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category @@ -42,40 +44,6 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' -_EXPECTED_CATEGORIES = [ - _Category( - index=934, - score=0.7939587831497192, - display_name='', - category_name='cheeseburger'), - _Category( - index=932, - score=0.02739289402961731, - display_name='', - category_name='bagel'), - _Category( - index=925, - score=0.01934075355529785, - display_name='', - category_name='guacamole'), - _Category( - index=963, - score=0.006327860057353973, - display_name='', - category_name='meat loaf') -] -_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( - classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=_EXPECTED_CATEGORIES, - timestamp_ms=0 - ) - ], - head_index=0, - head_name='probability') - ]) _EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( classifications=[ _Classifications( @@ -94,6 +62,60 @@ _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 +def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + 
classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=934, + score=0.7939587831497192, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02739289402961731, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01934075355529785, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006327860057353973, + display_name='', + category_name='meat loaf') + ], + timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + +def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category(index=806, score=0.9965274930000305, display_name='', + category_name='soccer ball') + ], + timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 @@ -138,8 +160,8 @@ class ImageClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _ImageClassifier) @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) def test_classify(self, model_file_type, max_results, expected_classification_result): # Creates classifier. 
@@ -167,8 +189,8 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) def test_classify_in_context(self, model_file_type, max_results, expected_classification_result): if model_file_type is ModelFileType.FILE_NAME: @@ -190,6 +212,23 @@ class ImageClassifierTest(parameterized.TestCase): # Comparing results. self.assertEqual(image_result, expected_classification_result) + def test_classify_succeeds_with_region_of_interest(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + # Performs image classification on the input. + image_result = classifier.classify(test_image, roi) + # Comparing results. 
+ self.assertEqual(image_result, _generate_soccer_ball_results(0)) + def test_score_threshold_option(self): classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( @@ -353,16 +392,27 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - expected_classification_result = _ClassificationResult( - classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp) - ], - head_index=0, head_name='probability') - ]) - self.assertEqual(classification_result, expected_classification_result) + self.assertEqual(classification_result, + _generate_burger_results(timestamp)) + + def test_classify_for_video_succeeds_with_region_of_interest(self): + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. 
+ roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + test_image, timestamp, roi) + self.assertEqual(classification_result, + _generate_soccer_ball_results(timestamp)) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7d44e3326..273804f0a 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -49,6 +49,7 @@ py_library( "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 36c5561c4..348e848ab 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -24,12 +24,14 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from 
mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_NormalizedRect = rect_module.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions @@ -42,10 +44,17 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' +_NORM_RECT_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +def _build_full_image_norm_rect() -> _NormalizedRect: + # Builds a NormalizedRect covering the entire image. + return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) + + @dataclasses.dataclass class ImageClassifierOptions: """Options for the image classifier task. @@ -145,6 +154,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, @@ -161,11 +171,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): def classify( self, image: image_module.Image, + roi: Optional[_NormalizedRect] = None ) -> classifications_module.ClassificationResult: """Performs image classification on the provided MediaPipe Image. Args: image: MediaPipe Image. + roi: The region of interest. Returns: A classification result object that contains a list of classifications. @@ -174,8 +186,10 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. 
""" - output_packets = self._process_image_data( - {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())}) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -186,7 +200,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): def classify_for_video( self, image: image_module.Image, - timestamp_ms: int + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None ) -> classifications_module.ClassificationResult: """Performs image classification on the provided video frames. @@ -198,6 +213,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. + roi: The region of interest. Returns: A classification result object that contains a list of classifications. @@ -206,10 +222,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. 
""" + norm_rect = roi if roi is not None else _build_full_image_norm_rect() output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -219,7 +237,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): for classification in classification_result_proto.classifications ]) - def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + def classify_async( + self, + image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None + ) -> None: """Sends live image data (an Image with a unique timestamp) to perform image classification. @@ -241,13 +264,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. + roi: The region of interest. Raises: ValueError: If the current input timestamp is smaller than what the image classifier has already processed. 
""" + norm_rect = roi if roi is not None else _build_full_image_norm_rect() self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) From c2672d040f9b79a77610efd6cd9f591f2b047f77 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 08:19:12 -0700 Subject: [PATCH 18/23] Updated error message for the invalid model path test case --- mediapipe/tasks/python/test/vision/image_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index b9098b55b..336f1f306 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -146,7 +146,7 @@ class ImageClassifierTest(parameterized.TestCase): with self.assertRaisesRegex( ValueError, r"ExternalFile must specify at least one of 'file_content', " - r"'file_name' or 'file_descriptor_meta'."): + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): base_options = _BaseOptions(model_asset_path='') options = _ImageClassifierOptions(base_options=base_options) _ImageClassifier.create_from_options(options) From 0a8dbc7576fd5381eefe98bf7e645f9f7b01ad04 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 13:58:37 -0700 Subject: [PATCH 19/23] Added remaining parameters to initialize the Rect data class --- mediapipe/tasks/python/components/containers/rect.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py index e74be1b0e..aadb404db 
100644 --- a/mediapipe/tasks/python/components/containers/rect.py +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -109,6 +109,8 @@ class NormalizedRect: y_center=self.y_center, width=self.width, height=self.height, + rotation=self.rotation, + rect_id=self.rect_id ) @classmethod @@ -119,7 +121,10 @@ class NormalizedRect: x_center=pb2_obj.x_center, y_center=pb2_obj.y_center, width=pb2_obj.width, - height=pb2_obj.height) + height=pb2_obj.height, + rotation=pb2_obj.rotation, + rect_id=pb2_obj.rect_id + ) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. From 8ea0018397fc98289e60d3498c90b98b82bb3e18 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:34:56 -0700 Subject: [PATCH 20/23] Added a check to see if the output packet is empty in the API and updated tests --- .../test/vision/image_classifier_test.py | 38 +++++++++---------- .../tasks/python/vision/image_classifier.py | 4 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 336f1f306..5bb479d7a 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -44,24 +44,27 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' -_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( - classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[], - timestamp_ms=0 - ) - ], - head_index=0, - head_name='probability') - ]) _ALLOW_LIST = ['cheeseburger', 'guacamole'] _DENY_LIST = ['cheeseburger'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 +def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[], + 
timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: return _ClassificationResult( classifications=[ @@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'Input timestamp must be monotonically increasing'): classifier.classify_async(self.test_image, 0) - # TODO: Fix the packet is empty issue. - """ - @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), - (1, _EMPTY_CLASSIFICATION_RESULT)) - def test_classify_async_calls(self, threshold, expected_result): + @parameterized.parameters((0, _generate_burger_results), + (1, _generate_empty_results)) + def test_classify_async_calls(self, threshold, expected_result_fn): observed_timestamp_ms = -1 def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - self.assertEqual(result, expected_result) + self.assertEqual(result, expected_result_fn(timestamp_ms)) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classifier.classify_async(self.test_image, timestamp) classifier.close() - """ if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 348e848ab..1d6b69778 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): """ def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): image = 
packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback(classification_result, image, - timestamp.value / _MICRO_SECONDS_PER_MILLISECOND) + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, From 7726205d851cab65bbe1e7508201704c57869ae1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:42:50 -0700 Subject: [PATCH 21/23] Added a test to run classify_async in region of interest mode --- .../test/vision/image_classifier_test.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 5bb479d7a..d2140f1da 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -476,6 +476,32 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_async(self.test_image, timestamp) classifier.close() + def test_classify_async_succeeds_with_region_of_interest(self): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, unused_output_image: _Image, + timestamp_ms: int): + self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=check_result) + classifier = _ImageClassifier.create_from_options(options) + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. 
+ roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) + classifier.close() + if __name__ == '__main__': absltest.main() From 6771fe69e9803bdfe8925dcb53b4445c83a72454 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:52:18 -0700 Subject: [PATCH 22/23] Included checks for image sizes while running in async and roi mode --- .../python/test/vision/image_classifier_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index d2140f1da..783718d06 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -477,11 +477,19 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() def test_classify_async_succeeds_with_region_of_interest(self): + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) observed_timestamp_ms = -1 - def check_result(result: _ClassificationResult, unused_output_image: _Image, + def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms @@ -492,12 +500,6 @@ class ImageClassifierTest(parameterized.TestCase): classifier_options=classifier_options, result_callback=check_result) classifier = _ImageClassifier.create_from_options(options) - # Load the test image. 
- test_image = _Image.create_from_file( - test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. - roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, - height=0.427) for timestamp in range(0, 300, 30): classifier.classify_async(test_image, timestamp, roi) classifier.close() From 803210a86b460f53fa6eb6c4ef4d67131e70d457 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 14 Oct 2022 03:00:29 -0700 Subject: [PATCH 23/23] Simplified async test cases to invoke the classifier in context --- .../python/test/vision/image_classifier_test.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 783718d06..a90ddd53e 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -471,10 +471,9 @@ class ImageClassifierTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, classifier_options=classifier_options, result_callback=check_result) - classifier = _ImageClassifier.create_from_options(options) - for timestamp in range(0, 300, 30): - classifier.classify_async(self.test_image, timestamp) - classifier.close() + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(self.test_image, timestamp) def test_classify_async_succeeds_with_region_of_interest(self): # Load the test image. 
@@ -499,10 +498,9 @@ class ImageClassifierTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, classifier_options=classifier_options, result_callback=check_result) - classifier = _ImageClassifier.create_from_options(options) - for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, roi) - classifier.close() + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) if __name__ == '__main__':