From cb52432159ecfea69c6df54be2cb56fd569f275f Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 8 Sep 2022 06:23:03 -0700 Subject: [PATCH 01/55] Added image classification implementation files and associated tests --- .../tasks/python/components/containers/BUILD | 10 + .../components/containers/classifications.py | 169 ++++++++++ mediapipe/tasks/python/test/vision/BUILD | 38 ++- .../test/vision/image_classification_test.py | 301 ++++++++++++++++++ mediapipe/tasks/python/vision/BUILD | 20 ++ .../python/vision/image_classification.py | 227 +++++++++++++ 6 files changed, 764 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/python/components/containers/classifications.py create mode 100644 mediapipe/tasks/python/test/vision/image_classification_test.py create mode 100644 mediapipe/tasks/python/vision/image_classification.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 2bc951220..eb3acdd97 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -47,3 +47,13 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "classifications", + srcs = ["classifications.py"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers:classifications_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py new file mode 100644 index 000000000..19c5decde --- /dev/null +++ b/mediapipe/tasks/python/components/containers/classifications.py @@ -0,0 +1,169 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifications data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components.containers import classifications_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassificationEntryProto = classifications_pb2.ClassificationEntry +_ClassificationsProto = classifications_pb2.Classifications +_ClassificationResultProto = classifications_pb2.ClassificationResult + + +@dataclasses.dataclass +class ClassificationEntry: + """List of predicted classes (aka labels) for a given classifier head. + + Attributes: + categories: The array of predicted categories, usually sorted by descending + scores (e.g. from high to low probability). + timestamp_ms: The optional timestamp (in milliseconds) associated to the + classification entry. This is useful for time series use cases, e.g., + audio classification. 
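+
+  Example (an illustrative construction only; `Category` fields mirror the
+  `category` container module):
+    entry = ClassificationEntry(
+        categories=[category_module.Category(
+            index=934, score=0.79, display_name='',
+            category_name='cheeseburger')],
+        timestamp_ms=0)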
+ """ + + categories: List[category_module.Category] + timestamp_ms: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationEntryProto: + """Generates a ClassificationEntry protobuf object.""" + return _ClassificationEntryProto( + categories=[category.to_pb2() for category in self.categories], + timestamp_ms=self.timestamp_ms) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry': + """Creates a `ClassificationEntry` object from the given protobuf object.""" + return ClassificationEntry( + categories=[ + category_module.Category.create_from_pb2(category) + for category in pb2_obj.categories + ], + timestamp_ms=pb2_obj.timestamp_ms) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationEntry): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Classifications: + """Represents the classifications for a given classifier head. + + Attributes: + entries: A list of `ClassificationEntry` objects. + head_index: The index of the classifier head these categories refer to. + This is useful for multi-head models. + head_name: The name of the classifier head, which is the corresponding + tensor metadata name. + """ + + entries: List[ClassificationEntry] + head_index: int + head_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationsProto: + """Generates a Classifications protobuf object.""" + return _ClassificationsProto( + entries=[entry.to_pb2() for entry in self.entries], + head_index=self.head_index, + head_name=self.head_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': + """Creates a `Classifications` object from the given protobuf object.""" + return Classifications( + entries=[ + ClassificationEntry.create_from_pb2(entry) + for entry in pb2_obj.entries + ], + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Classifications): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class ClassificationResult: + """Contains one set of results per classifier head. + + Attributes: + classifications: A list of `Classifications` objects. + """ + + classifications: List[Classifications] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationResultProto: + """Generates a ClassificationResult protobuf object.""" + return _ClassificationResultProto( + classifications=[ + classification.to_pb2() for classification in self.classifications + ]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult': + """Creates a `ClassificationResult` object from the given protobuf object.""" + return ClassificationResult( + classifications=[ + Classifications.create_from_pb2(classification) + for classification in pb2_obj.classifications + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. 
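+
+    Equality is delegated to the protobuf representation produced by
+    `to_pb2()`, so all fields participate in the comparison.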
+ + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index bb495338d..a63c36b55 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -18,4 +18,40 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -# TODO: This test fails in OSS +py_test( + name = "object_detector_test", + srcs = ["object_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + # build rule placeholder: numpy dep, + "//mediapipe/tasks/python/components/containers:bounding_box", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/vision:object_detector", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "image_classification_test", + srcs = ["image_classification_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/vision:image_classification", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py new file mode 100644 index 000000000..3650c547c --- /dev/null +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -0,0 +1,301 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for image classifier.""" + +import enum + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.vision import image_classification +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_Category = category_module.Category +_ClassificationEntry = classifications_module.ClassificationEntry +_Classifications = classifications_module.Classifications +_ClassificationResult = classifications_module.ClassificationResult +_Image = image_module.Image +_ImageClassifier = image_classification.ImageClassifier +_ImageClassifierOptions = image_classification.ImageClassifierOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode + +_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' +_IMAGE_FILE = 'burger.jpg' +_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=934, + score=0.7952049970626831, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02732999622821808, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01933487318456173, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006279350258409977, + display_name='', + category_name='meat loaf') + ], + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_ALLOW_LIST = ['cheeseburger', 'guacamole'] +_DENY_LIST = ['cheeseburger'] +_SCORE_THRESHOLD = 0.5 +_MAX_RESULTS = 3 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = test_util.read_test_image( + test_util.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_util.get_test_data_path(_MODEL_FILE) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _ImageClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(file_name=self.model_path) + options = _ImageClassifierOptions(base_options=base_options) + with _ImageClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + base_options = _BaseOptions(file_name='') + options = _ImageClassifierOptions(base_options=base_options) + _ImageClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. 
+    with open(self.model_path, 'rb') as f:
+      base_options = _BaseOptions(file_content=f.read())
+      options = _ImageClassifierOptions(base_options=base_options)
+      classifier = _ImageClassifier.create_from_options(options)
+      self.assertIsInstance(classifier, _ImageClassifier)
+
+  @parameterized.parameters(
+      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
+  def test_classify(self, model_file_type, max_results,
+                  expected_classification_result):
+    # Creates classifier.
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(file_name=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(file_content=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _ImageClassifierOptions(
+        base_options=base_options, max_results=max_results)
+    classifier = _ImageClassifier.create_from_options(options)
+
+    # Performs image classification on the input.
+    image_result = classifier.classify(self.test_image)
+    # Comparing results.
+    self.assertEqual(image_result, expected_classification_result)
+    # Closes the classifier explicitly when the classifier is not used in
+    # a context.
+    classifier.close()
+
+  @parameterized.parameters(
+      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
+  def test_classify_in_context(self, model_file_type, max_results,
+                  expected_classification_result):
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(file_name=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(file_content=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _ImageClassifierOptions(
+        base_options=base_options, max_results=max_results)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+      image_result = classifier.classify(self.test_image)
+      # Comparing results.
+      self.assertEqual(image_result, expected_classification_result)
+
+  def test_score_threshold_option(self):
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(file_name=self.model_path),
+        score_threshold=_SCORE_THRESHOLD)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+      image_result = classifier.classify(self.test_image)
+      classifications = image_result.classifications
+
+      for classification in classifications:
+        for entry in classification.entries:
+          score = entry.categories[0].score
+          self.assertGreaterEqual(
+              score, _SCORE_THRESHOLD,
+              f'Classification with score lower than threshold found. '
+              f'{classification}')
+
+  def test_max_results_option(self):
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(file_name=self.model_path),
+        max_results=_MAX_RESULTS)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
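+      # `max_results` should cap the number of returned categories; the
+      # assertion below checks the upper bound.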
+ image_result = classifier.classify(self.test_image) + categories = image_result.classifications[0].entries[0].categories + + self.assertLessEqual( + len(categories), _MAX_RESULTS, 'Too many results returned.') + + def test_allow_list_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_allowlist=_ALLOW_LIST) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') + + def test_deny_list_option(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_denylist=_DENY_LIST) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertNotIn(label, _DENY_LIST, + f'Label {label} found but in deny list.') + + def test_combined_allowlist_and_denylist(self): + # Fails with combined allowlist and denylist + with self.assertRaisesRegex( + ValueError, + r'`category_allowlist` and `category_denylist` are mutually ' + r'exclusive options.'): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + category_allowlist=['foo'], + category_denylist=['bar']) + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_empty_classification_outputs(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), score_threshold=1) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
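+      # With `score_threshold=1`, every category should be rejected, so the
+      # returned entry is expected to hold no categories.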
+ image_result = classifier.classify(self.test_image) + self.assertEmpty(image_result.classifications[0].entries[0].categories) + + def test_missing_result_callback(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + + def pass_through(unused_result: _ClassificationResult): + pass + + options = _ImageClassifierOptions( + base_options=_BaseOptions(file_name=self.model_path), + running_mode=running_mode, + result_callback=pass_through) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + # @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), + # (1, _ClassificationResult(classifications=[]))) + # def test_classify_async_calls(self, threshold, expected_result): + # observed_timestamp_ms = -1 + # + # def check_result(result: _ClassificationResult, timestamp_ms: int): + # self.assertEqual(result, expected_result) + # self.assertLess(observed_timestamp_ms, timestamp_ms) + # self.observed_timestamp_ms = timestamp_ms + # + # options = _ImageClassifierOptions( + # base_options=_BaseOptions(file_name=self.model_path), + # running_mode=_RUNNING_MODE.LIVE_STREAM, + # max_results=4, + # score_threshold=threshold, + # result_callback=check_result) + # classifier = _ImageClassifier.create_from_options(options) + # for timestamp in range(0, 300, 30): + # classifier.classify_async(self.test_image, timestamp) + # classifier.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7ff818610..7a27da179 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -36,3 +36,23 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "image_classification", + srcs = [ + "image_classification.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components:classifier_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py new file mode 100644 index 000000000..efe6aa11d --- /dev/null +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -0,0 +1,227 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe image classifier task.""" + +import dataclasses +from typing import Callable, List, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.tasks.cc.components import classifier_options_pb2 +from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions +_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_RunningMode = running_mode_module.VisionTaskRunningMode +_TaskInfo = task_info_module.TaskInfo +_TaskRunner = task_runner_module.TaskRunner + +_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' +_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph' + + +@dataclasses.dataclass +class ImageClassifierOptions: + """Options for the image classifier task. + + Attributes: + base_options: Base options for the image classifier task. + running_mode: The running mode of the task. Default to the image mode. + Image classifier task has three running modes: + 1) The image mode for classifying objects on single image inputs. + 2) The video mode for classifying objects on the decoded frames of a + video. + 3) The live stream mode for classifying objects on a live stream of input + data, such as from camera. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, detection + results whose category name is not in this set will be filtered out. + Duplicate or unknown category names are ignored. Mutually exclusive with + `category_denylist`. + category_denylist: Denylist of category names. If non-empty, detection + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. + result_callback: The user-defined result callback for processing live stream + data. 
The result callback should only be specified when the running mode + is set to the live stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + result_callback: Optional[ + Callable[[classifications_module.ClassificationResult], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ImageClassifierOptionsProto: + """Generates an ImageClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + + classifier_options_proto = _ClassifierOptionsProto( + display_names_locale=self.display_names_locale, + max_results=self.max_results, + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist) + + return _ImageClassifierOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto + ) + + +class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): + """Class that performs image classification on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'ImageClassifier': + """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. + + Note that the created `ImageClassifier` instance is in image mode, for + detecting objects on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `ImageClassifier` object that's created from the model file and the default + `ImageClassifierOptions`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(file_name=model_path) + options = ImageClassifierOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: ImageClassifierOptions) -> 'ImageClassifier': + """Creates the `ImageClassifier` object from image classifier options. + + Args: + options: Options for the image classifier task. + + Returns: + `ImageClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from + `ImageClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. 
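+
+    Example (an illustrative sketch only; the model path below is
+    hypothetical):
+      options = ImageClassifierOptions(
+          base_options=BaseOptions(file_name='/path/to/model.tflite'))
+      classifier = ImageClassifier.create_from_options(options)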
+ """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + classification_result = classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + options.result_callback(classification_result) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + output_streams=[ + ':'.join([_CLASSIFICATION_RESULT_TAG, + _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + # TODO: Create an Image class for MediaPipe Tasks. + def classify( + self, + image: image_module.Image + ) -> classifications_module.ClassificationResult: + """Performs image classification on the provided MediaPipe Image. + + Args: + image: MediaPipe Image. + + Returns: + A classification result object that contains a list of classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. + """ + output_packets = self._process_image_data( + {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + return classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + """Sends live image data (an Image with a unique timestamp) to perform image + classification. + + This method will return immediately after the input image is accepted. The + results will be available via the `result_callback` provided in the + `ImageClassifierOptions`. The `detect_async` method is designed to process + live stream data such as camera input. To lower the overall latency, image + classifier may drop the input images if needed. In other words, it's not + guaranteed to have output per input image. The `result_callback` provides: + - A classification result object that contains a list of classifications. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + + Raises: + ValueError: If the current input timestamp is smaller than what the image + classifier has already processed. 
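+
+    Example (an illustrative sketch only; assumes a classifier created with
+    the live stream running mode and a `result_callback`):
+      for timestamp_ms in range(0, 300, 30):
+        classifier.classify_async(image, timestamp_ms)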
+ """ + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at(timestamp_ms) + }) From ec0c5f43412eb521f2d943ee6eef0423c95baf3c Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 11 Sep 2022 14:00:49 -0700 Subject: [PATCH 02/55] Code cleanup --- .../test/vision/image_classification_test.py | 25 ++------------ .../python/vision/image_classification.py | 33 +++---------------- 2 files changed, 6 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 3650c547c..a96eee6cb 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -124,7 +124,7 @@ class ImageClassifierTest(parameterized.TestCase): (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) def test_classify(self, model_file_type, max_results, - expected_classification_result): + expected_classification_result): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -152,7 +152,7 @@ class ImageClassifierTest(parameterized.TestCase): (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) def test_classify_in_context(self, model_file_type, max_results, - expected_classification_result): + expected_classification_result): if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: @@ -275,27 +275,6 @@ class ImageClassifierTest(parameterized.TestCase): with _ImageClassifier.create_from_options(options) as unused_classifier: pass - # @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), - # (1, _ClassificationResult(classifications=[]))) - # def test_classify_async_calls(self, threshold, expected_result): - # observed_timestamp_ms = -1 - # - # def check_result(result: _ClassificationResult, timestamp_ms: int): - # self.assertEqual(result, expected_result) - # self.assertLess(observed_timestamp_ms, timestamp_ms) - # self.observed_timestamp_ms = timestamp_ms - # - # options = _ImageClassifierOptions( - # base_options=_BaseOptions(file_name=self.model_path), - # running_mode=_RUNNING_MODE.LIVE_STREAM, - # max_results=4, - # score_threshold=threshold, - # result_callback=check_result) - # classifier = _ImageClassifier.create_from_options(options) - # for timestamp in range(0, 300, 30): - # classifier.classify_async(self.test_image, timestamp) - # classifier.close() - if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index efe6aa11d..95381d78a 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -83,7 +83,8 @@ class ImageClassifierOptions: category_allowlist: Optional[List[str]] = None category_denylist: Optional[List[str]] = None result_callback: Optional[ - Callable[[classifications_module.ClassificationResult], None]] = None + Callable[[classifications_module.ClassificationResult], + None]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageClassifierOptionsProto: @@ -96,7 +97,8 @@ class ImageClassifierOptions: max_results=self.max_results, 
score_threshold=self.score_threshold, category_allowlist=self.category_allowlist, - category_denylist=self.category_denylist) + category_denylist=self.category_denylist + ) return _ImageClassifierOptionsProto( base_options=base_options_proto, @@ -198,30 +200,3 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) - - def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: - """Sends live image data (an Image with a unique timestamp) to perform image - classification. - - This method will return immediately after the input image is accepted. The - results will be available via the `result_callback` provided in the - `ImageClassifierOptions`. The `detect_async` method is designed to process - live stream data such as camera input. To lower the overall latency, image - classifier may drop the input images if needed. In other words, it's not - guaranteed to have output per input image. The `result_callback` provides: - - A classification result object that contains a list of classifications. - - The input image that the image classifier runs on. - - The input timestamp in milliseconds. - - Args: - image: MediaPipe Image. - timestamp_ms: The timestamp of the input image in milliseconds. - - Raises: - ValueError: If the current input timestamp is smaller than what the image - classifier has already processed. - """ - self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) - }) From 7287e5a0ed860a29abb8f1ad9c88d82021d1c8e1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 03:27:14 -0700 Subject: [PATCH 03/55] Added the image classifier task graph --- mediapipe/python/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 3a4a90b44..9fe21ab9e 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -86,6 +86,7 @@ cc_library( name = "builtin_task_graphs", deps = [ "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/image_classification:image_classifier_graph", ], ) From d8f7c5a43b311e166c337c451595348b5e1610d3 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 04:22:33 -0700 Subject: [PATCH 04/55] Moved ClassifierOptions to mediapipe/tasks/python/components to align with the C++ API --- mediapipe/tasks/python/components/BUILD | 28 ++++++ .../python/components/classifier_options.py | 92 +++++++++++++++++++ .../test/vision/image_classification_test.py | 31 +++++-- mediapipe/tasks/python/vision/BUILD | 2 +- .../python/vision/image_classification.py | 19 +--- 5 files changed, 146 insertions(+), 26 deletions(-) create mode 100644 mediapipe/tasks/python/components/BUILD create mode 100644 mediapipe/tasks/python/components/classifier_options.py diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD new file mode 100644 index 000000000..4094b7f7f --- /dev/null +++ b/mediapipe/tasks/python/components/BUILD @@ -0,0 +1,28 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Placeholder for internal Python strict library compatibility macro.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+py_library(
+    name = "classifier_options",
+    srcs = ["classifier_options.py"],
+    deps = [
+        "//mediapipe/tasks/cc/components:classifier_options_py_pb2",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)
diff --git a/mediapipe/tasks/python/components/classifier_options.py b/mediapipe/tasks/python/components/classifier_options.py
new file mode 100644
index 000000000..f6e61e48c
--- /dev/null
+++ b/mediapipe/tasks/python/components/classifier_options.py
@@ -0,0 +1,92 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classifier options data class."""
+
+import dataclasses
+from typing import Any, List, Optional
+
+from mediapipe.tasks.cc.components import classifier_options_pb2
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
+
+
+@dataclasses.dataclass
+class ClassifierOptions:
+  """Options for the classification processor.
+
+  Attributes:
+    display_names_locale: The locale to use for display names specified through
+      the TFLite Model Metadata.
+    max_results: The maximum number of top-scored classification results to
+      return.
+    score_threshold: Overrides the ones provided in the model metadata. Results
+      below this value are rejected.
+    category_allowlist: Allowlist of category names. If non-empty,
+      classification results whose category name is not in this set will be
+      filtered out. Duplicate or unknown category names are ignored. Mutually
+      exclusive with `category_denylist`.
+    category_denylist: Denylist of category names. If non-empty, classification
+      results whose category name is in this set will be filtered out. Duplicate
+      or unknown category names are ignored. Mutually exclusive with
+      `category_allowlist`.
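+
+  Example (an illustrative sketch only):
+    options = ClassifierOptions(max_results=3, score_threshold=0.5)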
+  """
+
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+  score_threshold: Optional[float] = None
+  category_allowlist: Optional[List[str]] = None
+  category_denylist: Optional[List[str]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _ClassifierOptionsProto:
+    """Generates a ClassifierOptions protobuf object."""
+    return _ClassifierOptionsProto(
+        score_threshold=self.score_threshold,
+        category_allowlist=self.category_allowlist,
+        category_denylist=self.category_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(
+      cls,
+      pb2_obj: _ClassifierOptionsProto
+  ) -> 'ClassifierOptions':
+    """Creates a `ClassifierOptions` object from the given protobuf object."""
+    return ClassifierOptions(
+        score_threshold=pb2_obj.score_threshold,
+        category_allowlist=[
+            str(name) for name in pb2_obj.category_allowlist
+        ],
+        category_denylist=[
+            str(name) for name in pb2_obj.category_denylist
+        ],
+        display_names_locale=pb2_obj.display_names_locale,
+        max_results=pb2_obj.max_results)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, ClassifierOptions):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py
index a96eee6cb..51dcb1adf 100644
--- a/mediapipe/tasks/python/test/vision/image_classification_test.py
+++ b/mediapipe/tasks/python/test/vision/image_classification_test.py
@@ -19,6 +19,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 
 from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.components import classifier_options
 from mediapipe.tasks.python.components.containers import category as category_module
 from mediapipe.tasks.python.components.containers import classifications as classifications_module
 from mediapipe.tasks.python.core import base_options as base_options_module
@@ -27,6 +28,7 @@ from mediapipe.tasks.python.vision import image_classification
 from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
 
 _BaseOptions = base_options_module.BaseOptions
+_ClassifierOptions = classifier_options.ClassifierOptions
 _Category = category_module.Category
 _ClassificationEntry = classifications_module.ClassificationEntry
 _Classifications = classifications_module.Classifications
@@ -136,8 +138,9 @@ class ImageClassifierTest(parameterized.TestCase):
     # Should never happen
     raise ValueError('model_file_type is invalid.')
 
+    classifier_options = _ClassifierOptions(max_results=max_results)
     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, classifier_options=classifier_options)
     classifier = _ImageClassifier.create_from_options(options)
 
     # Performs image classification on the input.
@@ -163,8 +166,9 @@ class ImageClassifierTest(parameterized.TestCase):
     # Should never happen
     raise ValueError('model_file_type is invalid.')
 
+    classifier_options = _ClassifierOptions(max_results=max_results)
     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, classifier_options=classifier_options)
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -172,9 +176,10 @@ class ImageClassifierTest(parameterized.TestCase):
     self.assertEqual(image_result, expected_classification_result)
 
   def test_score_threshold_option(self):
+    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(file_name=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        classifier_options=classifier_options)
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -189,9 +194,10 @@ class ImageClassifierTest(parameterized.TestCase):
             f'{classification}')
 
   def test_max_results_option(self):
+    classifier_options = _ClassifierOptions(max_results=_MAX_RESULTS)
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(file_name=self.model_path),
-        max_results=_MAX_RESULTS)
+        classifier_options=classifier_options)
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
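       # `max_results` should cap the number of returned categories; the
       # assertion below checks the upper bound.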
image_result = classifier.classify(self.test_image) @@ -236,16 +244,19 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): + classifier_options = _ClassifierOptions(category_allowlist=['foo'], + category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(file_name=self.model_path), - category_allowlist=['foo'], - category_denylist=['bar']) + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): + classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), score_threshold=1) + base_options=_BaseOptions(file_name=self.model_path), + classifier_options=classifier_options) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7a27da179..40caf129f 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,8 +46,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index 95381d78a..94176cdf8 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.components import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 +from mediapipe.tasks.python.components import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,8 +31,8 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions -_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _TaskRunner = task_runner_module.TaskRunner @@ -77,11 +77,7 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - display_names_locale: 
Optional[str] = None - max_results: Optional[int] = None - score_threshold: Optional[float] = None - category_allowlist: Optional[List[str]] = None - category_denylist: Optional[List[str]] = None + classifier_options: _ClassifierOptions = _ClassifierOptions() result_callback: Optional[ Callable[[classifications_module.ClassificationResult], None]] = None @@ -91,14 +87,7 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - - classifier_options_proto = _ClassifierOptionsProto( - display_names_locale=self.display_names_locale, - max_results=self.max_results, - score_threshold=self.score_threshold, - category_allowlist=self.category_allowlist, - category_denylist=self.category_denylist - ) + classifier_options_proto = self.classifier_options.to_pb2() return _ImageClassifierOptionsProto( base_options=base_options_proto, From bb750befd2e9ca066b76981404f09bbd91919a18 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 21 Sep 2022 04:24:11 -0700 Subject: [PATCH 05/55] Updated ImageClassifierOptions docstring --- .../tasks/python/vision/image_classification.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index 94176cdf8..f08da9f2b 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -57,20 +57,7 @@ class ImageClassifierOptions: video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - display_names_locale: The locale to use for display names specified through - the TFLite Model Metadata. - max_results: The maximum number of top-scored classification results to - return. - score_threshold: Overrides the ones provided in the model metadata. Results - below this value are rejected. - category_allowlist: Allowlist of category names. If non-empty, detection - results whose category name is not in this set will be filtered out. - Duplicate or unknown category names are ignored. Mutually exclusive with - `category_denylist`. - category_denylist: Denylist of category names. If non-empty, detection - results whose category name is in this set will be filtered out. Duplicate - or unknown category names are ignored. Mutually exclusive with - `category_allowlist`. + classifier_options: Options for the image classification task. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
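+
+  Example (an illustrative sketch only; the model path is hypothetical):
+    options = ImageClassifierOptions(
+        base_options=BaseOptions(file_name='/path/to/model.tflite'),
+        classifier_options=ClassifierOptions(max_results=3))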
From 72319ecbf57cb56502e8e9af912b7caa37e2b2eb Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:14:42 -0700 Subject: [PATCH 06/55] Updated BUILD --- mediapipe/tasks/python/test/vision/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index a63c36b55..46efc402d 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -26,7 +26,7 @@ py_test( "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ - # build rule placeholder: numpy dep, + "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:bounding_box", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:detections", @@ -34,7 +34,6 @@ py_test( "//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/vision:object_detector", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", - "@absl_py//absl/testing:parameterized", ], ) @@ -46,6 +45,7 @@ py_test( "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ + "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", From 8ad591822922338aca7cabd25fddd3bacdc98d9d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:18:10 -0700 Subject: [PATCH 07/55] Removed some tests --- .../test/vision/image_classification_test.py | 165 ------------------ 1 file changed, 165 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 51dcb1adf..0d7f9e7f0 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -92,36 +92,6 @@ class ImageClassifierTest(parameterized.TestCase): test_util.get_test_data_path(_IMAGE_FILE)) self.model_path = test_util.get_test_data_path(_MODEL_FILE) - def test_create_from_file_succeeds_with_valid_model_path(self): - # Creates with default option and valid model file successfully. - with _ImageClassifier.create_from_model_path(self.model_path) as classifier: - self.assertIsInstance(classifier, _ImageClassifier) - - def test_create_from_options_succeeds_with_valid_model_path(self): - # Creates with options containing model file successfully. - base_options = _BaseOptions(file_name=self.model_path) - options = _ImageClassifierOptions(base_options=base_options) - with _ImageClassifier.create_from_options(options) as classifier: - self.assertIsInstance(classifier, _ImageClassifier) - - def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. - with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name' or 'file_descriptor_meta'."): - base_options = _BaseOptions(file_name='') - options = _ImageClassifierOptions(base_options=base_options) - _ImageClassifier.create_from_options(options) - - def test_create_from_options_succeeds_with_valid_model_content(self): - # Creates with options containing model content successfully. 
-    with open(self.model_path, 'rb') as f:
-      base_options = _BaseOptions(file_content=f.read())
-      options = _ImageClassifierOptions(base_options=base_options)
-      classifier = _ImageClassifier.create_from_options(options)
-      self.assertIsInstance(classifier, _ImageClassifier)
-
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
       (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
@@ -151,141 +121,6 @@ class ImageClassifierTest(parameterized.TestCase):
     # a context.
     classifier.close()
 
-  @parameterized.parameters(
-      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
-      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
-  def test_classify_in_context(self, model_file_type, max_results,
-                               expected_classification_result):
-    if model_file_type is ModelFileType.FILE_NAME:
-      base_options = _BaseOptions(file_name=self.model_path)
-    elif model_file_type is ModelFileType.FILE_CONTENT:
-      with open(self.model_path, 'rb') as f:
-        model_content = f.read()
-      base_options = _BaseOptions(file_content=model_content)
-    else:
-      # Should never happen
-      raise ValueError('model_file_type is invalid.')
-
-    classifier_options = _ClassifierOptions(max_results=max_results)
-    options = _ImageClassifierOptions(
-        base_options=base_options, classifier_options=classifier_options)
-    with _ImageClassifier.create_from_options(options) as classifier:
-      # Performs image classification on the input.
-      image_result = classifier.classify(self.test_image)
-      # Comparing results.
-      self.assertEqual(image_result, expected_classification_result)
-
-  def test_score_threshold_option(self):
-    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
-    options = _ImageClassifierOptions(
-        base_options=_BaseOptions(file_name=self.model_path),
-        classifier_options=classifier_options)
-    with _ImageClassifier.create_from_options(options) as classifier:
-      # Performs image classification on the input.
-      image_result = classifier.classify(self.test_image)
-      classifications = image_result.classifications
-
-      for classification in classifications:
-        for entry in classification.entries:
-          score = entry.categories[0].score
-          self.assertGreaterEqual(
-              score, _SCORE_THRESHOLD,
-              f'Classification with score lower than threshold found. '
-              f'{classification}')
-
-  def test_max_results_option(self):
-    classifier_options = _ClassifierOptions(max_results=_MAX_RESULTS)
-    options = _ImageClassifierOptions(
-        base_options=_BaseOptions(file_name=self.model_path),
-        classifier_options=classifier_options)
-    with _ImageClassifier.create_from_options(options) as classifier:
-      # Performs image classification on the input.
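-      # `max_results` should cap the number of returned categories; the
-      # assertion below checks the upper bound.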
- image_result = classifier.classify(self.test_image) - classifications = image_result.classifications - - for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name - self.assertIn(label, _ALLOW_LIST, - f'Label {label} found but not in label allow list') - - def test_deny_list_option(self): - classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. - image_result = classifier.classify(self.test_image) - classifications = image_result.classifications - - for classification in classifications: - for entry in classification.entries: - label = entry.categories[0].category_name - self.assertNotIn(label, _DENY_LIST, - f'Label {label} found but in deny list.') - - def test_combined_allowlist_and_denylist(self): - # Fails with combined allowlist and denylist - with self.assertRaisesRegex( - ValueError, - r'`category_allowlist` and `category_denylist` are mutually ' - r'exclusive options.'): - classifier_options = _ClassifierOptions(category_allowlist=['foo'], - category_denylist=['bar']) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - - def test_empty_classification_outputs(self): - classifier_options = _ClassifierOptions(score_threshold=1) - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - classifier_options=classifier_options) - with _ImageClassifier.create_from_options(options) as classifier: - # Performs image classification on the input. 
- image_result = classifier.classify(self.test_image) - self.assertEmpty(image_result.classifications[0].entries[0].categories) - - def test_missing_result_callback(self): - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - - @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) - def test_illegal_result_callback(self, running_mode): - - def pass_through(unused_result: _ClassificationResult): - pass - - options = _ImageClassifierOptions( - base_options=_BaseOptions(file_name=self.model_path), - running_mode=running_mode, - result_callback=pass_through) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): - with _ImageClassifier.create_from_options(options) as unused_classifier: - pass - if __name__ == '__main__': absltest.main() From b04af0cafa7b8ec0d48122d20c7bd2f7693b12f0 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 02:45:20 -0700 Subject: [PATCH 08/55] Updated implementation and tests --- mediapipe/python/BUILD | 2 +- mediapipe/tasks/python/components/proto/BUILD | 29 +++++++++++++++++++ .../{ => proto}/classifier_options.py | 0 .../test/vision/image_classification_test.py | 4 +-- mediapipe/tasks/python/vision/BUILD | 4 +-- .../python/vision/image_classification.py | 6 ++-- 6 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 mediapipe/tasks/python/components/proto/BUILD rename mediapipe/tasks/python/components/{ => proto}/classifier_options.py (100%) diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 9fe21ab9e..33667d18e 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -86,7 +86,7 @@ cc_library( name = "builtin_task_graphs", deps = [ "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", - "//mediapipe/tasks/cc/vision/image_classification:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", ], ) diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/proto/BUILD new file mode 100644 index 000000000..7814ec675 --- /dev/null +++ b/mediapipe/tasks/python/components/proto/BUILD @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. 
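+# Python dataclass wrappers around the MediaPipe Tasks component protos
+# (for now, just the classifier options) live in this package.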
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "classifier_options", + srcs = ["classifier_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + diff --git a/mediapipe/tasks/python/components/classifier_options.py b/mediapipe/tasks/python/components/proto/classifier_options.py similarity index 100% rename from mediapipe/tasks/python/components/classifier_options.py rename to mediapipe/tasks/python/components/proto/classifier_options.py diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classification_test.py index 0d7f9e7f0..08d850dab 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classification_test.py @@ -99,11 +99,11 @@ class ImageClassifierTest(parameterized.TestCase): expected_classification_result): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(file_name=self.model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: with open(self.model_path, 'rb') as f: model_content = f.read() - base_options = _BaseOptions(file_content=model_content) + base_options = _BaseOptions(model_asset_buffer=model_content) else: # Should never happen raise ValueError('model_file_type is invalid.') diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 40caf129f..5e5bfe3ba 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,8 +46,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/image_classification:image_classifier_options_py_pb2", - "//mediapipe/tasks/python/components:classifier_options", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2", + "//mediapipe/tasks/python/components/proto:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classification.py index f08da9f2b..c240ffedf 100644 --- a/mediapipe/tasks/python/vision/image_classification.py +++ b/mediapipe/tasks/python/vision/image_classification.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.image_classification import image_classifier_options_pb2 -from mediapipe.tasks.python.components import classifier_options +from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2 +from mediapipe.tasks.python.components.proto import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -104,7 +104,7 @@ class 
ImageClassifier(base_vision_task_api.BaseVisionTaskApi): file such as invalid file path. RuntimeError: If other types of error occurred. """ - base_options = _BaseOptions(file_name=model_path) + base_options = _BaseOptions(model_asset_path=model_path) options = ImageClassifierOptions( base_options=base_options, running_mode=_RunningMode.IMAGE) return cls.create_from_options(options) From cba2a6035c55842c440cb8be78244794eca38cf4 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 03:15:17 -0700 Subject: [PATCH 09/55] Code cleanup --- .../tasks/python/components/proto/classifier_options.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 7 ++++--- ...ge_classification_test.py => image_classifier_test.py} | 8 ++++---- mediapipe/tasks/python/vision/BUILD | 4 ++-- .../{image_classification.py => image_classifier.py} | 0 5 files changed, 11 insertions(+), 10 deletions(-) rename mediapipe/tasks/python/test/vision/{image_classification_test.py => image_classifier_test.py} (94%) rename mediapipe/tasks/python/vision/{image_classification.py => image_classifier.py} (100%) diff --git a/mediapipe/tasks/python/components/proto/classifier_options.py b/mediapipe/tasks/python/components/proto/classifier_options.py index f6e61e48c..c73828c13 100644 --- a/mediapipe/tasks/python/components/proto/classifier_options.py +++ b/mediapipe/tasks/python/components/proto/classifier_options.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components import classifier_options_pb2 +from mediapipe.tasks.cc.components.proto import classifier_options_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index c96be63f8..6b6b9e3e2 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -38,19 +38,20 @@ py_test( ) py_test( - name = "image_classification_test", - srcs = ["image_classification_test.py"], + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], data = [ "//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_models", ], deps = [ "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/proto:classifier_options", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_util", - "//mediapipe/tasks/python/vision:image_classification", + "//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) diff --git a/mediapipe/tasks/python/test/vision/image_classification_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py similarity index 94% rename from mediapipe/tasks/python/test/vision/image_classification_test.py rename to mediapipe/tasks/python/test/vision/image_classifier_test.py index 08d850dab..3d7d116df 100644 --- a/mediapipe/tasks/python/test/vision/image_classification_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,12 +19,12 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components import classifier_options +from mediapipe.tasks.python.components.proto 
import classifier_options from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_util -from mediapipe.tasks.python.vision import image_classification +from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions @@ -34,8 +34,8 @@ _ClassificationEntry = classifications_module.ClassificationEntry _Classifications = classifications_module.Classifications _ClassificationResult = classifications_module.ClassificationResult _Image = image_module.Image -_ImageClassifier = image_classification.ImageClassifier -_ImageClassifierOptions = image_classification.ImageClassifierOptions +_ImageClassifier = image_classifier.ImageClassifier +_ImageClassifierOptions = image_classifier.ImageClassifierOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 5e5bfe3ba..a724f149a 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -38,9 +38,9 @@ py_library( ) py_library( - name = "image_classification", + name = "image_classifier", srcs = [ - "image_classification.py", + "image_classifier.py", ], deps = [ "//mediapipe/python:_framework_bindings", diff --git a/mediapipe/tasks/python/vision/image_classification.py b/mediapipe/tasks/python/vision/image_classifier.py similarity index 100% rename from mediapipe/tasks/python/vision/image_classification.py rename to mediapipe/tasks/python/vision/image_classifier.py From 68fea17e30a5c316f30617e1bda2e53336b41108 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 13:34:10 -0700 Subject: [PATCH 10/55] Removed unused dependencies in BUILD --- mediapipe/tasks/python/components/BUILD | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD index 4094b7f7f..00fd4061c 100644 --- a/mediapipe/tasks/python/components/BUILD +++ b/mediapipe/tasks/python/components/BUILD @@ -17,12 +17,3 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) - -py_library( - name = "classifier_options", - srcs = ["classifier_options.py"], - deps = [ - "//mediapipe/tasks/cc/components:classifier_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) From 85af7ac9bcbf024178d8f389c8251aed2c5da36d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 29 Sep 2022 13:36:07 -0700 Subject: [PATCH 11/55] Removed unused BUILD --- mediapipe/tasks/python/components/BUILD | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 mediapipe/tasks/python/components/BUILD diff --git a/mediapipe/tasks/python/components/BUILD b/mediapipe/tasks/python/components/BUILD deleted file mode 100644 index 00fd4061c..000000000 --- a/mediapipe/tasks/python/components/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Placeholder for internal Python strict library compatibility macro. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) From f241630b56261ac5009ba6638c7f5b8db9e699e1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 3 Oct 2022 15:11:48 -0700 Subject: [PATCH 12/55] Revised implementation to align with recent changes --- .../components/containers/{ => proto}/BUILD | 2 +- .../containers/{ => proto}/__init__.py | 0 .../containers/{ => proto}/bounding_box.py | 0 .../containers/{ => proto}/category.py | 0 .../containers/{ => proto}/classifications.py | 4 ++-- .../containers/{ => proto}/detections.py | 0 .../components/{ => processors}/proto/BUILD | 3 +-- .../proto/classifier_options.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 8 ++++---- .../test/vision/image_classifier_test.py | 14 ++++++------- mediapipe/tasks/python/vision/BUILD | 6 +++--- .../tasks/python/vision/image_classifier.py | 20 ++++++++++--------- 12 files changed, 30 insertions(+), 29 deletions(-) rename mediapipe/tasks/python/components/containers/{ => proto}/BUILD (95%) rename mediapipe/tasks/python/components/containers/{ => proto}/__init__.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/bounding_box.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/category.py (100%) rename mediapipe/tasks/python/components/containers/{ => proto}/classifications.py (96%) rename mediapipe/tasks/python/components/containers/{ => proto}/detections.py (100%) rename mediapipe/tasks/python/components/{ => processors}/proto/BUILD (91%) rename mediapipe/tasks/python/components/{ => processors}/proto/classifier_options.py (97%) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/proto/BUILD similarity index 95% rename from mediapipe/tasks/python/components/containers/BUILD rename to mediapipe/tasks/python/components/containers/proto/BUILD index 450111161..d46df22da 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/proto/BUILD @@ -53,7 +53,7 @@ py_library( srcs = ["classifications.py"], deps = [ ":category", - "//mediapipe/tasks/cc/components/containers:classifications_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/__init__.py b/mediapipe/tasks/python/components/containers/proto/__init__.py similarity index 100% rename from mediapipe/tasks/python/components/containers/__init__.py rename to mediapipe/tasks/python/components/containers/proto/__init__.py diff --git a/mediapipe/tasks/python/components/containers/bounding_box.py b/mediapipe/tasks/python/components/containers/proto/bounding_box.py similarity index 100% rename from mediapipe/tasks/python/components/containers/bounding_box.py rename to mediapipe/tasks/python/components/containers/proto/bounding_box.py diff --git a/mediapipe/tasks/python/components/containers/category.py 
b/mediapipe/tasks/python/components/containers/proto/category.py similarity index 100% rename from mediapipe/tasks/python/components/containers/category.py rename to mediapipe/tasks/python/components/containers/proto/category.py diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/proto/classifications.py similarity index 96% rename from mediapipe/tasks/python/components/containers/classifications.py rename to mediapipe/tasks/python/components/containers/proto/classifications.py index 19c5decde..2c43370b2 100644 --- a/mediapipe/tasks/python/components/containers/classifications.py +++ b/mediapipe/tasks/python/components/containers/proto/classifications.py @@ -16,8 +16,8 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components.containers import classifications_pb2 -from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.python.components.containers.proto import category as category_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationEntryProto = classifications_pb2.ClassificationEntry diff --git a/mediapipe/tasks/python/components/containers/detections.py b/mediapipe/tasks/python/components/containers/proto/detections.py similarity index 100% rename from mediapipe/tasks/python/components/containers/detections.py rename to mediapipe/tasks/python/components/containers/proto/detections.py diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/processors/proto/BUILD similarity index 91% rename from mediapipe/tasks/python/components/proto/BUILD rename to mediapipe/tasks/python/components/processors/proto/BUILD index 7814ec675..814e15d1f 100644 --- a/mediapipe/tasks/python/components/proto/BUILD +++ b/mediapipe/tasks/python/components/processors/proto/BUILD @@ -22,8 +22,7 @@ py_library( name = "classifier_options", srcs = ["classifier_options.py"], deps = [ - "//mediapipe/tasks/cc/components/proto:classifier_options_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) - diff --git a/mediapipe/tasks/python/components/proto/classifier_options.py b/mediapipe/tasks/python/components/processors/proto/classifier_options.py similarity index 97% rename from mediapipe/tasks/python/components/proto/classifier_options.py rename to mediapipe/tasks/python/components/processors/proto/classifier_options.py index c73828c13..b4597e57a 100644 --- a/mediapipe/tasks/python/components/proto/classifier_options.py +++ b/mediapipe/tasks/python/components/processors/proto/classifier_options.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any, List, Optional -from mediapipe.tasks.cc.components.proto import classifier_options_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index df2e72f98..107a78a33 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -46,11 +46,11 @@ py_test( ], deps = [ "//mediapipe/python:_framework_bindings", - 
"//mediapipe/tasks/python/components/proto:classifier_options", - "//mediapipe/tasks/python/components/containers:category", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors/proto:classifier_options", + "//mediapipe/tasks/python/components/containers/proto:category", + "//mediapipe/tasks/python/components/containers/proto:classifications", "//mediapipe/tasks/python/core:base_options", - "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 3d7d116df..c16587cb5 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,11 +19,11 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.proto import classifier_options -from mediapipe.tasks.python.components.containers import category as category_module -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.processors.proto import classifier_options +from mediapipe.tasks.python.components.containers.proto import category as category_module +from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module @@ -88,9 +88,9 @@ class ImageClassifierTest(parameterized.TestCase): def setUp(self): super().setUp() - self.test_image = test_util.read_test_image( - test_util.get_test_data_path(_IMAGE_FILE)) - self.model_path = test_util.get_test_data_path(_MODEL_FILE) + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_utils.get_test_data_path(_MODEL_FILE) @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index a724f149a..69fd8f2a6 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -46,9 +46,9 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_py_pb2", - "//mediapipe/tasks/python/components/proto:classifier_options", - "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/processors/proto:classifier_options", + "//mediapipe/tasks/python/components/containers/proto:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py 
b/mediapipe/tasks/python/vision/image_classifier.py index c240ffedf..809429c37 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_options_pb2 -from mediapipe.tasks.python.components.proto import classifier_options -from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.processors.proto import classifier_options +from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions -_ImageClassifierOptionsProto = image_classifier_options_pb2.ImageClassifierOptions +_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -41,7 +41,7 @@ _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_TAG = 'IMAGE' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageClassifierGraph' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' @dataclasses.dataclass @@ -70,13 +70,13 @@ class ImageClassifierOptions: None]] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ImageClassifierOptionsProto: + def to_pb2(self) -> _ImageClassifierGraphOptionsProto: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True classifier_options_proto = self.classifier_options.to_pb2() - return _ImageClassifierOptionsProto( + return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, classifier_options=classifier_options_proto ) @@ -138,7 +138,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, - input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -153,7 +155,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): # TODO: Create an Image class for MediaPipe Tasks. def classify( self, - image: image_module.Image + image: image_module.Image, ) -> classifications_module.ClassificationResult: """Performs image classification on the provided MediaPipe Image. 
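For reference, the Python API that emerges from [PATCH 12/55] can be exercised end to end as sketched below. This is a minimal illustrative sketch, not part of the patch series: the model and image paths are placeholders, the import paths reflect the package layout as of this patch (a later patch in the series adjusts the namespaces again), and it assumes a local MediaPipe build that exposes these modules.

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.processors.proto import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_classifier

# Request the top four categories, mirroring the options used in the tests.
options = image_classifier.ImageClassifierOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/mobilenet_v2_1.0_224.tflite'),
    classifier_options=classifier_options.ClassifierOptions(max_results=4))

with image_classifier.ImageClassifier.create_from_options(options) as classifier:
  # Image.create_from_file decodes an image file into a MediaPipe Image.
  image = image_module.Image.create_from_file('/path/to/burger.jpg')
  result = classifier.classify(image)
  for classification in result.classifications:
    for entry in classification.entries:
      for category in entry.categories:
        print(f'{category.category_name}: {category.score:.4f}')

The same ImageClassifierOptions also accepts a running_mode (image, video, or live stream) and, for the live stream mode, a result_callback, which the later patches in this series build out and test.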
From 64d5c159c6f509b656b8e28ed91375c24db3e35b Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 3 Oct 2022 16:14:35 -0700 Subject: [PATCH 13/55] Fixed an auto formatting issue that caused classification_posprocessing_graph's registration to fail --- .../processors/classification_postprocessing_graph.cc | 4 ++-- .../tasks/python/test/vision/image_classifier_test.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 35adab687..649ff2c11 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -507,8 +507,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: - ClassificationPostprocessingGraph); // NOLINT +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT } // namespace processors } // namespace components diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index c16587cb5..0ce70395e 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -48,22 +48,22 @@ _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( categories=[ _Category( index=934, - score=0.7952049970626831, + score=0.7939587831497192, display_name='', category_name='cheeseburger'), _Category( index=932, - score=0.02732999622821808, + score=0.02739289402961731, display_name='', category_name='bagel'), _Category( index=925, - score=0.01933487318456173, + score=0.01934075355529785, display_name='', category_name='guacamole'), _Category( index=963, - score=0.006279350258409977, + score=0.006327860057353973, display_name='', category_name='meat loaf') ], From a22a5283d272771323186af02ce061c74a32c501 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 5 Oct 2022 04:43:30 -0700 Subject: [PATCH 14/55] Adjusted namespaces --- .../tasks/python/components/containers/{proto => }/BUILD | 0 .../python/components/containers/{proto => }/__init__.py | 0 .../components/containers/{proto => }/bounding_box.py | 0 .../python/components/containers/{proto => }/category.py | 0 .../components/containers/{proto => }/classifications.py | 2 +- .../python/components/containers/{proto => }/detections.py | 0 .../tasks/python/components/processors/{proto => }/BUILD | 0 .../components/processors/{proto => }/classifier_options.py | 0 mediapipe/tasks/python/test/vision/BUILD | 6 +++--- mediapipe/tasks/python/test/vision/image_classifier_test.py | 6 +++--- mediapipe/tasks/python/vision/BUILD | 4 ++-- mediapipe/tasks/python/vision/image_classifier.py | 4 ++-- 12 files changed, 11 insertions(+), 11 deletions(-) rename mediapipe/tasks/python/components/containers/{proto => }/BUILD (100%) rename mediapipe/tasks/python/components/containers/{proto => }/__init__.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/bounding_box.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/category.py (100%) rename mediapipe/tasks/python/components/containers/{proto => }/classifications.py (98%) rename mediapipe/tasks/python/components/containers/{proto => }/detections.py (100%) rename 
mediapipe/tasks/python/components/processors/{proto => }/BUILD (100%) rename mediapipe/tasks/python/components/processors/{proto => }/classifier_options.py (100%) diff --git a/mediapipe/tasks/python/components/containers/proto/BUILD b/mediapipe/tasks/python/components/containers/BUILD similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/BUILD rename to mediapipe/tasks/python/components/containers/BUILD diff --git a/mediapipe/tasks/python/components/containers/proto/__init__.py b/mediapipe/tasks/python/components/containers/__init__.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/__init__.py rename to mediapipe/tasks/python/components/containers/__init__.py diff --git a/mediapipe/tasks/python/components/containers/proto/bounding_box.py b/mediapipe/tasks/python/components/containers/bounding_box.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/bounding_box.py rename to mediapipe/tasks/python/components/containers/bounding_box.py diff --git a/mediapipe/tasks/python/components/containers/proto/category.py b/mediapipe/tasks/python/components/containers/category.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/category.py rename to mediapipe/tasks/python/components/containers/category.py diff --git a/mediapipe/tasks/python/components/containers/proto/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py similarity index 98% rename from mediapipe/tasks/python/components/containers/proto/classifications.py rename to mediapipe/tasks/python/components/containers/classifications.py index 2c43370b2..c51816e07 100644 --- a/mediapipe/tasks/python/components/containers/proto/classifications.py +++ b/mediapipe/tasks/python/components/containers/classifications.py @@ -17,7 +17,7 @@ import dataclasses from typing import Any, List, Optional from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 -from mediapipe.tasks.python.components.containers.proto import category as category_module +from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationEntryProto = classifications_pb2.ClassificationEntry diff --git a/mediapipe/tasks/python/components/containers/proto/detections.py b/mediapipe/tasks/python/components/containers/detections.py similarity index 100% rename from mediapipe/tasks/python/components/containers/proto/detections.py rename to mediapipe/tasks/python/components/containers/detections.py diff --git a/mediapipe/tasks/python/components/processors/proto/BUILD b/mediapipe/tasks/python/components/processors/BUILD similarity index 100% rename from mediapipe/tasks/python/components/processors/proto/BUILD rename to mediapipe/tasks/python/components/processors/BUILD diff --git a/mediapipe/tasks/python/components/processors/proto/classifier_options.py b/mediapipe/tasks/python/components/processors/classifier_options.py similarity index 100% rename from mediapipe/tasks/python/components/processors/proto/classifier_options.py rename to mediapipe/tasks/python/components/processors/classifier_options.py diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 107a78a33..09bbf1958 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -46,9 +46,9 @@ py_test( ], deps = [ 
"//mediapipe/python:_framework_bindings", - "//mediapipe/tasks/python/components/processors/proto:classifier_options", - "//mediapipe/tasks/python/components/containers/proto:category", - "//mediapipe/tasks/python/components/containers/proto:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 0ce70395e..7fbd96ddc 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -19,9 +19,9 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.processors.proto import classifier_options -from mediapipe.tasks.python.components.containers.proto import category as category_module -from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 69fd8f2a6..7d44e3326 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -47,8 +47,8 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", - "//mediapipe/tasks/python/components/processors/proto:classifier_options", - "//mediapipe/tasks/python/components/containers/proto:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/components/containers:classifications", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 809429c37..94dcd4d70 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -22,8 +22,8 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 -from mediapipe.tasks.python.components.processors.proto import classifier_options -from mediapipe.tasks.python.components.containers.proto import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.components.containers import 
classifications as classifications_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls From e250c903f5b6e8c87b9a4cef175137df95802804 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 5 Oct 2022 05:24:52 -0700 Subject: [PATCH 15/55] Added the ClassifyForVideo API --- .../test/vision/image_classifier_test.py | 13 +++++ .../tasks/python/vision/image_classifier.py | 52 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 7fbd96ddc..5143a28db 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -15,6 +15,7 @@ import enum +import numpy as np from absl.testing import absltest from absl.testing import parameterized @@ -121,6 +122,18 @@ class ImageClassifierTest(parameterized.TestCase): # a context. classifier.close() + def test_classify_for_video(self): + classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + self.test_image, timestamp) + self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 94dcd4d70..b3bafa113 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,7 +14,7 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, List, Mapping, Optional +from typing import Callable, Mapping, Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -40,6 +40,7 @@ _TaskRunner = task_runner_module.TaskRunner _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' @@ -66,8 +67,8 @@ class ImageClassifierOptions: running_mode: _RunningMode = _RunningMode.IMAGE classifier_options: _ClassifierOptions = _ClassifierOptions() result_callback: Optional[ - Callable[[classifications_module.ClassificationResult], - None]] = None + Callable[[classifications_module.ClassificationResult, image_module.Image, + int], None]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageClassifierGraphOptionsProto: @@ -134,7 +135,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) - options.result_callback(classification_result) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp + options.result_callback(classification_result, image, timestamp) task_info = 
_TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -143,7 +146,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, - _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + _CLASSIFICATION_RESULT_OUT_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ], task_options=options) return cls( @@ -175,6 +179,40 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) return classifications_module.ClassificationResult([ - classifications_module.Classifications.create_from_pb2(classification) - for classification in classification_result_proto.classifications + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_for_video( + self, image: image_module.Image, + timestamp_ms: int + ) -> classifications_module.ClassificationResult: + """Performs image classification on the provided video frames. + + Only use this method when the ImageClassifier is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + + Returns: + A classification result object that contains a list of classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. + """ + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at(timestamp_ms) + }) + classification_result_proto = packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) + + return classifications_module.ClassificationResult([ + classifications_module.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications ]) From cb806071ba30ba66cfa1ec26d80b9b5c0b3e2501 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 7 Oct 2022 22:26:49 -0700 Subject: [PATCH 16/55] Added more tests and updated the APIs to use a new constant --- .../test/vision/image_classifier_test.py | 338 ++++++++++++++++-- .../tasks/python/vision/image_classifier.py | 45 ++- 2 files changed, 355 insertions(+), 28 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 5143a28db..073674c3f 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -14,6 +14,7 @@ """Tests for image classifier.""" import enum +from unittest import mock import numpy as np from absl.testing import absltest @@ -41,33 +42,46 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' +_EXPECTED_CATEGORIES = [ + _Category( + index=934, + score=0.7939587831497192, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02739289402961731, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01934075355529785, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006327860057353973, + display_name='', + category_name='meat loaf') +] _EXPECTED_CLASSIFICATION_RESULT = 
_ClassificationResult( classifications=[ _Classifications( entries=[ _ClassificationEntry( - categories=[ - _Category( - index=934, - score=0.7939587831497192, - display_name='', - category_name='cheeseburger'), - _Category( - index=932, - score=0.02739289402961731, - display_name='', - category_name='bagel'), - _Category( - index=925, - score=0.01934075355529785, - display_name='', - category_name='guacamole'), - _Category( - index=963, - score=0.006327860057353973, - display_name='', - category_name='meat loaf') - ], + categories=_EXPECTED_CATEGORIES, + timestamp_ms=0 + ) + ], + head_index=0, + head_name='probability') + ]) +_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[], timestamp_ms=0 ) ], @@ -93,6 +107,36 @@ class ImageClassifierTest(parameterized.TestCase): test_utils.get_test_data_path(_IMAGE_FILE)) self.model_path = test_utils.get_test_data_path(_MODEL_FILE) + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _ImageClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageClassifierOptions(base_options=base_options) + with _ImageClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _ImageClassifierOptions(base_options=base_options) + _ImageClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _ImageClassifierOptions(base_options=base_options) + classifier = _ImageClassifier.create_from_options(options) + self.assertIsInstance(classifier, _ImageClassifier) + @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) @@ -122,6 +166,183 @@ class ImageClassifierTest(parameterized.TestCase): # a context. 
classifier.close() + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + def test_classify_in_context(self, model_file_type, max_results, + expected_classification_result): + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + classifier_options = _ClassifierOptions(max_results=max_results) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + # Comparing results. + self.assertEqual(image_result, expected_classification_result) + + def test_score_threshold_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + score = entry.categories[0].score + self.assertGreaterEqual( + score, _SCORE_THRESHOLD, + f'Classification with score lower than threshold found. ' + f'{classification}') + + def test_max_results_option(self): + classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + categories = image_result.classifications[0].entries[0].categories + + self.assertLessEqual( + len(categories), _MAX_RESULTS, 'Too many results returned.') + + def test_allow_list_option(self): + classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') + + def test_deny_list_option(self): + classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+ image_result = classifier.classify(self.test_image) + classifications = image_result.classifications + + for classification in classifications: + for entry in classification.entries: + label = entry.categories[0].category_name + self.assertNotIn(label, _DENY_LIST, + f'Label {label} found but in deny list.') + + def test_combined_allowlist_and_denylist(self): + # Fails with combined allowlist and denylist + with self.assertRaisesRegex( + ValueError, + r'`category_allowlist` and `category_denylist` are mutually ' + r'exclusive options.'): + classifier_options = _ClassifierOptions(category_allowlist=['foo'], + category_denylist=['bar']) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_empty_classification_outputs(self): + classifier_options = _ClassifierOptions(score_threshold=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + self.assertEmpty(image_result.classifications[0].entries[0].categories) + + def test_missing_result_callback(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_calling_classify_for_video_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_calling_classify_async_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_calling_classify_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_async_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + 
with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_classify_for_video_with_out_of_order_timestamp(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + unused_result = classifier.classify_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_for_video(self.test_image, 0) + def test_classify_for_video(self): classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( @@ -132,7 +353,78 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT) + expected_classification_result = _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp) + ], + head_index=0, head_name='probability') + ]) + self.assertEqual(classification_result, expected_classification_result) + + def test_calling_classify_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_for_video_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_classify_async_calls_with_illegal_timestamp(self): + classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + classifier.classify_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_async(self.test_image, 0) + + # TODO: Fix the packet is empty issue. 
+ """ + @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), + (1, _EMPTY_CLASSIFICATION_RESULT)) + def test_classify_async_calls(self, threshold, expected_result): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, output_image: _Image, + timestamp_ms: int): + self.assertEqual(result, expected_result) + self.assertTrue( + np.array_equal(output_image.numpy_view(), + self.test_image.numpy_view())) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + classifier_options = _ClassifierOptions( + max_results=4, score_threshold=threshold) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=check_result) + classifier = _ImageClassifier.create_from_options(options) + for timestamp in range(0, 300, 30): + classifier.classify_async(self.test_image, timestamp) + classifier.close() + """ if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index b3bafa113..36c5561c4 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -43,6 +43,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 @dataclasses.dataclass @@ -91,7 +92,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. Note that the created `ImageClassifier` instance is in image mode, for - detecting objects on single image inputs. + classifying objects on single image inputs. Args: model_path: Path to the model. @@ -137,7 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp - options.result_callback(classification_result, image, timestamp) + options.result_callback(classification_result, image, + timestamp.value / _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -156,7 +158,6 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): _RunningMode.LIVE_STREAM), options.running_mode, packets_callback if options.result_callback else None) - # TODO: Create an Image class for MediaPipe Tasks. def classify( self, image: image_module.Image, @@ -206,8 +207,9 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If image classification failed to run. 
""" output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -216,3 +218,36 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classifications_module.Classifications.create_from_pb2(classification) for classification in classification_result_proto.classifications ]) + + def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + """Sends live image data (an Image with a unique timestamp) to perform + image classification. + + Only use this method when the ImageClassifier is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `ImageClassifierOptions`. The + `classify_async` method is designed to process live stream data such as + camera input. To lower the overall latency, image classifier may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - A classification result object that contains a list of classifications. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + + Raises: + ValueError: If the current input timestamp is smaller than what the image + classifier has already processed. + """ + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) From 44e6f8e1a1edea42bc92c1e9c83d59ddce678fb8 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 08:15:40 -0700 Subject: [PATCH 17/55] Updated image classifier to use a region of interest parameter --- .../tasks/python/components/containers/BUILD | 9 ++ .../python/components/containers/rect.py | 136 ++++++++++++++++ mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/image_classifier_test.py | 146 ++++++++++++------ mediapipe/tasks/python/vision/BUILD | 1 + .../tasks/python/vision/image_classifier.py | 46 ++++-- 6 files changed, 281 insertions(+), 58 deletions(-) create mode 100644 mediapipe/tasks/python/components/containers/rect.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d46df22da..723210f5f 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -27,6 +27,15 @@ py_library( ], ) +py_library( + name = "rect", + srcs = ["rect.py"], + deps = [ + "//mediapipe/framework/formats:rect_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( name = "category", srcs = ["category.py"], diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py new file mode 100644 index 000000000..e74be1b0e --- /dev/null +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -0,0 +1,136 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rect data class."""
+
+import dataclasses
+from typing import Any, Optional
+
+from mediapipe.framework.formats import rect_pb2
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+_RectProto = rect_pb2.Rect
+_NormalizedRectProto = rect_pb2.NormalizedRect
+
+
+@dataclasses.dataclass
+class Rect:
+  """A rectangle with rotation in image coordinates.
+
+  Attributes:
+    x_center: The X coordinate of the rectangle center, in pixels.
+    y_center: The Y coordinate of the rectangle center, in pixels.
+    width: The width of the rectangle, in pixels.
+    height: The height of the rectangle, in pixels.
+    rotation: The rotation angle, in radians, applied clockwise.
+    rect_id: Optional unique id to help associate different rectangles to each
+      other.
+  """
+
+  x_center: int
+  y_center: int
+  width: int
+  height: int
+  rotation: Optional[float] = 0.0
+  rect_id: Optional[int] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _RectProto:
+    """Generates a Rect protobuf object."""
+    return _RectProto(
+        x_center=self.x_center,
+        y_center=self.y_center,
+        width=self.width,
+        height=self.height,
+    )
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
+    """Creates a `Rect` object from the given protobuf object."""
+    return Rect(
+        x_center=pb2_obj.x_center,
+        y_center=pb2_obj.y_center,
+        width=pb2_obj.width,
+        height=pb2_obj.height)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, Rect):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
+
+
+@dataclasses.dataclass
+class NormalizedRect:
+  """A rectangle with rotation in normalized coordinates.
+
+  The values of the box center location and size are within [0, 1].
+
+  Attributes:
+    x_center: The normalized X coordinate of the rectangle center.
+    y_center: The normalized Y coordinate of the rectangle center.
+    width: The width of the rectangle.
+    height: The height of the rectangle.
+    rotation: The rotation angle, in radians, applied clockwise.
+    rect_id: Optional unique id to help associate different rectangles to each
+      other.
+ """ + + x_center: float + y_center: float + width: float + height: float + rotation: Optional[float] = 0.0 + rect_id: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _NormalizedRectProto: + """Generates a NormalizedRect protobuf object.""" + return _NormalizedRectProto( + x_center=self.x_center, + y_center=self.y_center, + width=self.width, + height=self.height, + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect': + """Creates a `NormalizedRect` object from the given protobuf object.""" + return NormalizedRect( + x_center=pb2_obj.x_center, + y_center=pb2_obj.y_center, + width=pb2_obj.width, + height=pb2_obj.height) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedRect): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 09bbf1958..e4d331784 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,6 +49,7 @@ py_test( "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 073674c3f..b9098b55b 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -24,11 +24,13 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_NormalizedRect = rect_module.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category @@ -42,40 +44,6 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' -_EXPECTED_CATEGORIES = [ - _Category( - index=934, - score=0.7939587831497192, - display_name='', - category_name='cheeseburger'), - _Category( - index=932, - score=0.02739289402961731, - display_name='', - category_name='bagel'), - _Category( - index=925, - score=0.01934075355529785, - display_name='', - category_name='guacamole'), - _Category( - index=963, - score=0.006327860057353973, - display_name='', - category_name='meat loaf') -] -_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( - 
classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=_EXPECTED_CATEGORIES, - timestamp_ms=0 - ) - ], - head_index=0, - head_name='probability') - ]) _EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( classifications=[ _Classifications( @@ -94,6 +62,60 @@ _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 +def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=934, + score=0.7939587831497192, + display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02739289402961731, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01934075355529785, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006327860057353973, + display_name='', + category_name='meat loaf') + ], + timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + +def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category(index=806, score=0.9965274930000305, display_name='', + category_name='soccer ball') + ], + timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 @@ -138,8 +160,8 @@ class ImageClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _ImageClassifier) @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) def test_classify(self, model_file_type, max_results, expected_classification_result): # Creates classifier. @@ -167,8 +189,8 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() @parameterized.parameters( - (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) def test_classify_in_context(self, model_file_type, max_results, expected_classification_result): if model_file_type is ModelFileType.FILE_NAME: @@ -190,6 +212,23 @@ class ImageClassifierTest(parameterized.TestCase): # Comparing results. self.assertEqual(image_result, expected_classification_result) + def test_classify_succeeds_with_region_of_interest(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + # Performs image classification on the input. + image_result = classifier.classify(test_image, roi) + # Comparing results. 
+ self.assertEqual(image_result, _generate_soccer_ball_results(0)) + def test_score_threshold_option(self): classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( @@ -353,16 +392,27 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - expected_classification_result = _ClassificationResult( - classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp) - ], - head_index=0, head_name='probability') - ]) - self.assertEqual(classification_result, expected_classification_result) + self.assertEqual(classification_result, + _generate_burger_results(timestamp)) + + def test_classify_for_video_succeeds_with_region_of_interest(self): + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + test_image, timestamp, roi) + self.assertEqual(classification_result, + _generate_soccer_ball_results(timestamp)) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7d44e3326..273804f0a 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -49,6 +49,7 @@ py_library( "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 36c5561c4..348e848ab 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -24,12 +24,14 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_NormalizedRect = rect_module.NormalizedRect _BaseOptions = 
base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions @@ -42,10 +44,17 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' +_NORM_RECT_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +def _build_full_image_norm_rect() -> _NormalizedRect: + # Builds a NormalizedRect covering the entire image. + return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) + + @dataclasses.dataclass class ImageClassifierOptions: """Options for the image classifier task. @@ -145,6 +154,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), ], output_streams=[ ':'.join([_CLASSIFICATION_RESULT_TAG, @@ -161,11 +171,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): def classify( self, image: image_module.Image, + roi: Optional[_NormalizedRect] = None ) -> classifications_module.ClassificationResult: """Performs image classification on the provided MediaPipe Image. Args: image: MediaPipe Image. + roi: The region of interest. Returns: A classification result object that contains a list of classifications. @@ -174,8 +186,10 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. """ - output_packets = self._process_image_data( - {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())}) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -186,7 +200,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): def classify_for_video( self, image: image_module.Image, - timestamp_ms: int + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None ) -> classifications_module.ClassificationResult: """Performs image classification on the provided video frames. @@ -198,6 +213,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. + roi: The region of interest. Returns: A classification result object that contains a list of classifications. @@ -206,10 +222,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. 
""" + norm_rect = roi if roi is not None else _build_full_image_norm_rect() output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -219,7 +237,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): for classification in classification_result_proto.classifications ]) - def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None: + def classify_async( + self, + image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None + ) -> None: """Sends live image data (an Image with a unique timestamp) to perform image classification. @@ -241,13 +264,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. + roi: The region of interest. Raises: ValueError: If the current input timestamp is smaller than what the image classifier has already processed. """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) From c2672d040f9b79a77610efd6cd9f591f2b047f77 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 08:19:12 -0700 Subject: [PATCH 18/55] Updated error message for the invalid model path test case --- mediapipe/tasks/python/test/vision/image_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index b9098b55b..336f1f306 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -146,7 +146,7 @@ class ImageClassifierTest(parameterized.TestCase): with self.assertRaisesRegex( ValueError, r"ExternalFile must specify at least one of 'file_content', " - r"'file_name' or 'file_descriptor_meta'."): + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): base_options = _BaseOptions(model_asset_path='') options = _ImageClassifierOptions(base_options=base_options) _ImageClassifier.create_from_options(options) From 0a8dbc7576fd5381eefe98bf7e645f9f7b01ad04 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 10 Oct 2022 13:58:37 -0700 Subject: [PATCH 19/55] Added remaining parameters to initialize the Rect data class --- mediapipe/tasks/python/components/containers/rect.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py index e74be1b0e..aadb404db 100644 --- a/mediapipe/tasks/python/components/containers/rect.py +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -109,6 +109,8 @@ class NormalizedRect: 
y_center=self.y_center, width=self.width, height=self.height, + rotation=self.rotation, + rect_id=self.rect_id ) @classmethod @@ -119,7 +121,10 @@ class NormalizedRect: x_center=pb2_obj.x_center, y_center=pb2_obj.y_center, width=pb2_obj.width, - height=pb2_obj.height) + height=pb2_obj.height, + rotation=pb2_obj.rotation, + rect_id=pb2_obj.rect_id + ) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. From 8ea0018397fc98289e60d3498c90b98b82bb3e18 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:34:56 -0700 Subject: [PATCH 20/55] Added a check to see if the output packet is empty in the API and updated tests --- .../test/vision/image_classifier_test.py | 38 +++++++++---------- .../tasks/python/vision/image_classifier.py | 4 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 336f1f306..5bb479d7a 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -44,24 +44,27 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _IMAGE_FILE = 'burger.jpg' -_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult( - classifications=[ - _Classifications( - entries=[ - _ClassificationEntry( - categories=[], - timestamp_ms=0 - ) - ], - head_index=0, - head_name='probability') - ]) _ALLOW_LIST = ['cheeseburger', 'guacamole'] _DENY_LIST = ['cheeseburger'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 +def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult( + classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[], + timestamp_ms=timestamp_ms + ) + ], + head_index=0, + head_name='probability') + ]) + + def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: return _ClassificationResult( classifications=[ @@ -447,16 +450,14 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'Input timestamp must be monotonically increasing'): classifier.classify_async(self.test_image, 0) - # TODO: Fix the packet is empty issue. 
- """ - @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT), - (1, _EMPTY_CLASSIFICATION_RESULT)) - def test_classify_async_calls(self, threshold, expected_result): + @parameterized.parameters((0, _generate_burger_results), + (1, _generate_empty_results)) + def test_classify_async_calls(self, threshold, expected_result_fn): observed_timestamp_ms = -1 def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): - self.assertEqual(result, expected_result) + self.assertEqual(result, expected_result_fn(timestamp_ms)) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -474,7 +475,6 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classifier.classify_async(self.test_image, timestamp) classifier.close() - """ if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 348e848ab..1d6b69778 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -138,6 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): """ def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return classification_result_proto = packet_getter.get_proto( output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) @@ -148,7 +150,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback(classification_result, image, - timestamp.value / _MICRO_SECONDS_PER_MILLISECOND) + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, From 7726205d851cab65bbe1e7508201704c57869ae1 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:42:50 -0700 Subject: [PATCH 21/55] Added a test to run classify_async in region of interest mode --- .../test/vision/image_classifier_test.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 5bb479d7a..d2140f1da 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -476,6 +476,32 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_async(self.test_image, timestamp) classifier.close() + def test_classify_async_succeeds_with_region_of_interest(self): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, unused_output_image: _Image, + timestamp_ms: int): + self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=check_result) + classifier = _ImageClassifier.create_from_options(options) + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. 
+ roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) + classifier.close() + if __name__ == '__main__': absltest.main() From 6771fe69e9803bdfe8925dcb53b4445c83a72454 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 11 Oct 2022 22:52:18 -0700 Subject: [PATCH 22/55] Included checks for image sizes while running in async and roi mode --- .../python/test/vision/image_classifier_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index d2140f1da..783718d06 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -477,11 +477,19 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() def test_classify_async_succeeds_with_region_of_interest(self): + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) observed_timestamp_ms = -1 - def check_result(result: _ClassificationResult, unused_output_image: _Image, + def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms @@ -492,12 +500,6 @@ class ImageClassifierTest(parameterized.TestCase): classifier_options=classifier_options, result_callback=check_result) classifier = _ImageClassifier.create_from_options(options) - # Load the test image. - test_image = _Image.create_from_file( - test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. 
- roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, - height=0.427) for timestamp in range(0, 300, 30): classifier.classify_async(test_image, timestamp, roi) classifier.close() From 803210a86b460f53fa6eb6c4ef4d67131e70d457 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 14 Oct 2022 03:00:29 -0700 Subject: [PATCH 23/55] Simplified async test cases to invoke the classifier in context --- .../python/test/vision/image_classifier_test.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 783718d06..a90ddd53e 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -471,10 +471,9 @@ class ImageClassifierTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, classifier_options=classifier_options, result_callback=check_result) - classifier = _ImageClassifier.create_from_options(options) - for timestamp in range(0, 300, 30): - classifier.classify_async(self.test_image, timestamp) - classifier.close() + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(self.test_image, timestamp) def test_classify_async_succeeds_with_region_of_interest(self): # Load the test image. @@ -499,10 +498,9 @@ class ImageClassifierTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, classifier_options=classifier_options, result_callback=check_result) - classifier = _ImageClassifier.create_from_options(options) - for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, roi) - classifier.close() + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) if __name__ == '__main__': From 36d69971a7b645273374410a589f4afdb186272d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 13:33:22 -0700 Subject: [PATCH 24/55] Internal change PiperOrigin-RevId: 482875698 --- mediapipe/calculators/tensor/BUILD | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index ae8a0cbf0..e953342da 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -307,6 +307,27 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "universal_sentence_encoder_preprocessor_calculator_test", + srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"], + deps = [ + ":universal_sentence_encoder_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], From d0437b7f9152a723a48835916872ea7b051ac96d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 
15:25:38 -0700
Subject: [PATCH 25/55] Add tensor_index and tensor_name fields to
 ClassificationList

PiperOrigin-RevId: 482901854
---
 .../tensors_to_classification_calculator.cc   |  7 +++++
 ...tensors_to_classification_calculator.proto |  5 ++++
 ...nsors_to_classification_calculator_test.cc | 30 +++++++++++++++++++
 .../framework/formats/classification.proto    |  4 +++
 4 files changed, 46 insertions(+)

diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
index 5bfc00ed7..76d2869e8 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc
@@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
 }
 
 absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
+  const auto& options = cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
   const auto& input_tensors = *kInTensors(cc);
   RET_CHECK_EQ(input_tensors.size(), 1);
   RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
   auto raw_scores = view.buffer<float>();
   auto classification_list = absl::make_unique<ClassificationList>();
+  if (options.has_tensor_index()) {
+    classification_list->set_tensor_index(options.tensor_index());
+  }
+  if (options.has_tensor_name()) {
+    classification_list->set_tensor_name(options.tensor_name());
+  }
   if (is_binary_classification_) {
     Classification* class_first = classification_list->add_classification();
     Classification* class_second = classification_list->add_classification();
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
index 32bc4b63a..f0f7727ba 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto
@@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
   // that are not in the `allow_classes` field will be completely ignored.
   // `ignore_classes` and `allow_classes` are mutually exclusive.
   repeated int32 allow_classes = 8 [packed = true];
+
+  // The optional index of the tensor these classifications originate from.
+  optional int32 tensor_index = 10;
+  // The optional name of the tensor these classifications originate from.
+  optional string tensor_name = 11;
 }
diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
index 9634635f0..b20f2768c 100644
--- a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
+++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc
@@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
   }
 }
 
+TEST_F(TensorsToClassificationCalculatorTest,
+       CorrectOutputWithTensorNameAndIndex) {
+  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
+    calculator: "TensorsToClassificationCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "CLASSIFICATIONS:classifications"
+    options {
+      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+        tensor_index: 1
+        tensor_name: "foo"
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<mediapipe::ClassificationList>();
+  EXPECT_EQ(3, classification_list.classification_size());
+
+  // Verify that the tensor_index and tensor_name fields are correctly set.
+  EXPECT_EQ(classification_list.tensor_index(), 1);
+  EXPECT_EQ(classification_list.tensor_name(), "foo");
+}
+
 TEST_F(TensorsToClassificationCalculatorTest,
        ClassNameAllowlistWithLabelItems) {
   mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
diff --git a/mediapipe/framework/formats/classification.proto b/mediapipe/framework/formats/classification.proto
index 7efd9074d..c3eea07ff 100644
--- a/mediapipe/framework/formats/classification.proto
+++ b/mediapipe/framework/formats/classification.proto
@@ -37,6 +37,10 @@ message Classification {
 // Group of Classification protos.
 message ClassificationList {
   repeated Classification classification = 1;
+  // Optional index of the tensor that produced these classifications.
+  optional int32 tensor_index = 2;
+  // Optional name of the tensor that produced these classifications.
+  optional string tensor_name = 3;
 }
 
 // Group of ClassificationList protos.
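The two new optional fields make the originating output tensor of each ClassificationList recoverable downstream, which matters for multi-head models whose graphs emit one ClassificationList per head. As a consumer-side sketch (not part of the patch; the helper function below is illustrative), the provenance can be read back like this:

#include <iostream>

#include "mediapipe/framework/formats/classification.pb.h"

// Prints which output tensor a ClassificationList came from. Both fields are
// optional proto fields: tensor_index() falls back to 0 and tensor_name() to
// "" when the calculator options left them unset.
void PrintHeadProvenance(const mediapipe::ClassificationList& list) {
  std::cout << "head #" << list.tensor_index() << " ('" << list.tensor_name()
            << "') produced " << list.classification_size()
            << " classifications\n";
}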
From 4a6c23a76a70369ba5a1a65789fcfc2d6497cc82 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 21 Oct 2022 15:47:13 -0700
Subject: [PATCH 26/55] Internal change

PiperOrigin-RevId: 482906478
---
 mediapipe/tasks/cc/vision/core/BUILD          |  11 ++
 .../cc/vision/core/base_vision_task_api.h     |  59 ++++++++++
 .../cc/vision/core/image_processing_options.h |  52 +++++++++
 .../tasks/cc/vision/gesture_recognizer/BUILD  |   1 +
 .../gesture_recognizer/gesture_recognizer.cc  |  47 +++-----
 .../gesture_recognizer/gesture_recognizer.h   |  36 +++---
 .../tasks/cc/vision/image_classifier/BUILD    |   1 +
 .../image_classifier/image_classifier.cc      |  37 ++----
 .../image_classifier/image_classifier.h       |  31 ++---
 .../image_classifier/image_classifier_test.cc | 109 +++++++++++++-----
 .../tasks/cc/vision/object_detector/BUILD     |   1 +
 .../vision/object_detector/object_detector.cc |  47 +++-----
 .../vision/object_detector/object_detector.h  |  29 +++--
 .../object_detector/object_detector_test.cc   |  22 ++--
 14 files changed, 302 insertions(+), 181 deletions(-)
 create mode 100644 mediapipe/tasks/cc/vision/core/image_processing_options.h

diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD
index 12d789901..e8e197a1d 100644
--- a/mediapipe/tasks/cc/vision/core/BUILD
+++ b/mediapipe/tasks/cc/vision/core/BUILD
@@ -21,12 +21,23 @@ cc_library(
     hdrs = ["running_mode.h"],
 )
 
+cc_library(
+    name = "image_processing_options",
+    hdrs = ["image_processing_options.h"],
+    deps = [
+        "//mediapipe/tasks/cc/components/containers:rect",
+    ],
+)
+
 cc_library(
     name = "base_vision_task_api",
     hdrs = ["base_vision_task_api.h"],
     deps = [
+        ":image_processing_options",
         ":running_mode",
         "//mediapipe/calculators/core:flow_limiter_calculator",
+        "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/tasks/cc/components/containers:rect",
         "//mediapipe/tasks/cc/core:base_task_api",
         "//mediapipe/tasks/cc/core:task_runner",
         "@com_google_absl//absl/status",
diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
index 4586cbbdd..c3c0a0261 100644
--- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
+++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h
@@ -16,15 +16,20 @@ limitations under the License.
 #ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_
 #define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_
 
+#include <math.h>
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/base_task_api.h"
 #include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 
 namespace mediapipe {
@@ -87,6 +92,60 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
     return runner_->Send(std::move(inputs));
   }
 
+  // Converts from ImageProcessingOptions to NormalizedRect, performing sanity
+  // checks on-the-fly. If the input ImageProcessingOptions is not present,
+  // returns a default NormalizedRect covering the whole image with rotation set
+  // to 0. If 'roi_allowed' is false, an error will be returned if the input
+  // ImageProcessingOptions has its 'region_of_interest' field set.
+  static absl::StatusOr<mediapipe::NormalizedRect> ConvertToNormalizedRect(
+      std::optional<ImageProcessingOptions> options, bool roi_allowed = true) {
+    mediapipe::NormalizedRect normalized_rect;
+    normalized_rect.set_rotation(0);
+    normalized_rect.set_x_center(0.5);
+    normalized_rect.set_y_center(0.5);
+    normalized_rect.set_width(1.0);
+    normalized_rect.set_height(1.0);
+    if (!options.has_value()) {
+      return normalized_rect;
+    }
+
+    if (options->rotation_degrees % 90 != 0) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          "Expected rotation to be a multiple of 90°.",
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+    }
+    // Convert to radians counter-clockwise.
+    normalized_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0);
+
+    if (options->region_of_interest.has_value()) {
+      if (!roi_allowed) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "This task doesn't support region-of-interest.",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      auto& roi = *options->region_of_interest;
+      if (roi.left >= roi.right || roi.top >= roi.bottom) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "Expected Rect with left < right and top < bottom.",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "Expected Rect values to be in [0,1].",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
+      normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0);
+      normalized_rect.set_width(roi.right - roi.left);
+      normalized_rect.set_height(roi.bottom - roi.top);
+    }
+    return normalized_rect;
+  }
+
  private:
   RunningMode running_mode_;
 };
diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h
new file mode 100644
index 000000000..7e764c1fe
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h
@@ -0,0 +1,52 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+
+#include <optional>
+
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace core {
+
+// Options for image processing.
+//
+// If both region-of-interest and rotation are specified, the crop around the
+// region-of-interest is extracted first, then the specified rotation is
+// applied to the crop.
+struct ImageProcessingOptions {
+  // The optional region-of-interest to crop from the image. If not specified,
+  // the full image is used.
+  //
+  // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+  std::optional<components::containers::Rect> region_of_interest = std::nullopt;
+
+  // The rotation to apply to the image (or cropped region-of-interest), in
+  // degrees clockwise.
+  //
+  // The rotation must be a multiple (positive or negative) of 90°.
+  int rotation_degrees = 0;
+};
+
+}  // namespace core
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
index e5b1f0479..a766c6b3f 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
@@ -137,6 +137,7 @@ cc_library(
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
         "//mediapipe/tasks/cc/vision/core:running_mode",
         "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
index 333edb6fb..000a2e141 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/task_runner.h"
 #include "mediapipe/tasks/cc/core/utils.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
@@ -76,31 +77,6 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
 constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
 
-// Returns a NormalizedRect filling the whole image. If input is present, its
-// rotation is set in the returned NormalizedRect and a check is performed to
-// make sure no region-of-interest was provided. Otherwise, rotation is set to
-// 0.
-absl::StatusOr<NormalizedRect> FillNormalizedRect(
-    std::optional<NormalizedRect> normalized_rect) {
-  NormalizedRect result;
-  if (normalized_rect.has_value()) {
-    result = *normalized_rect;
-  }
-  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
-                         result.has_width() || result.has_height();
-  if (has_coordinates) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "GestureRecognizer does not support region-of-interest.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  result.set_x_center(0.5);
-  result.set_y_center(0.5);
-  result.set_width(1);
-  result.set_height(1);
-  return result;
-}
-
 // Creates a MediaPipe graph config that contains a subgraph node of
 // "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the
@@ -248,15 +224,16 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
 
 absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
     mediapipe::Image image,
-    std::optional<NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         "GPU input images are currently not supported.",
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessImageData(
@@ -283,15 +260,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
 
 absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
     mediapipe::Image image, int64 timestamp_ms,
-    std::optional<NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
@@ -321,15 +299,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
 
 absl::Status GestureRecognizer::RecognizeAsync(
     mediapipe::Image image, int64 timestamp_ms,
-    std::optional<NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
index 750a99797..29c8bea7b 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
@@ -23,10 +23,10 @@ limitations under the License.
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
-#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 
 namespace mediapipe {
 namespace tasks {
 namespace vision {
@@ -129,36 +129,36 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
   // Only use this method when the GestureRecognizer is created with the image
   // running mode.
   //
- // imageProcessingOptions - std::optional - // If provided, can be used to specify the rotation to apply to the image - // before performing classification, by setting its 'rotation' field in - // radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that - // specifying a region-of-interest using the 'x_center', 'y_center', 'width' - // and 'height' fields is NOT supported and will result in an invalid - // argument error being returned. + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. // // The image can be of any size with format RGB or RGBA. // TODO: Describes how the input image will be preprocessed // after the yuv support is implemented. - // TODO: use an ImageProcessingOptions struct instead of - // NormalizedRect. absl::StatusOr Recognize( Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs gesture recognition on the provided video frame. // Only use this method when the GestureRecognizer is created with the video // running mode. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. absl::StatusOr RecognizeForVideo(Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Sends live image data to perform gesture recognition, and the results will @@ -171,6 +171,12 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // sent to the gesture recognizer. The input timestamps must be monotonically // increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The "result_callback" provides // - A vector of GestureRecognitionResult, each is the recognized results // for a input frame. @@ -180,7 +186,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. absl::Status RecognizeAsync(Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Shuts down the GestureRecognizer when all works are done. 
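As an editorial aside before the image classifier changes below: the per-task
FillNormalizedRect helpers deleted in this patch are replaced by a shared
ConvertToNormalizedRect(image_processing_options, roi_allowed) utility whose
implementation is not part of this diff. The sketch below reconstructs the
behavior implied by the deleted helpers and by the tests further down (the
error strings, the multiple-of-90° rule, and the degrees-to-radians swap); the
function name suffix and the ordering of the checks are assumptions, not the
shipped library source.

    // Sketch only: reconstructed from the evidence in this patch.
    absl::StatusOr<NormalizedRect> ConvertToNormalizedRectSketch(
        std::optional<core::ImageProcessingOptions> options,
        bool roi_allowed = true) {
      // Default to a rect covering the full image, as the deleted
      // FillNormalizedRect helpers did.
      NormalizedRect result;
      result.set_x_center(0.5);
      result.set_y_center(0.5);
      result.set_width(1);
      result.set_height(1);
      if (!options.has_value()) return result;
      if (options->rotation_degrees % 90 != 0) {
        return CreateStatusWithPayload(
            absl::StatusCode::kInvalidArgument,
            "Expected rotation to be a multiple of 90°.",
            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
      }
      // ImageProcessingOptions counts degrees clockwise while NormalizedRect
      // uses radians anti-clockwise: the tests replace
      // set_rotation(M_PI / 2.0) with rotation_degrees = -90, which pins
      // this conversion.
      result.set_rotation(-options->rotation_degrees * M_PI / 180.0);
      if (options->region_of_interest.has_value()) {
        if (!roi_allowed) {
          return CreateStatusWithPayload(
              absl::StatusCode::kInvalidArgument,
              "This task doesn't support region-of-interest.",
              MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
        }
        const auto& roi = *options->region_of_interest;
        if (roi.left >= roi.right || roi.top >= roi.bottom) {
          return CreateStatusWithPayload(
              absl::StatusCode::kInvalidArgument,
              "Expected Rect with left < right and top < bottom.",
              MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
        }
        if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
          return CreateStatusWithPayload(
              absl::StatusCode::kInvalidArgument,
              "Expected Rect values to be in [0,1].",
              MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
        }
        // Mapping pinned by the updated classifier tests, where center
        // (0.532, 0.521) / size (0.164, 0.427) becomes
        // {left=0.45, top=0.3075, right=0.614, bottom=0.7345}.
        result.set_x_center((roi.left + roi.right) / 2.0);
        result.set_y_center((roi.top + roi.bottom) / 2.0);
        result.set_width(roi.right - roi.left);
        result.set_height(roi.bottom - roi.top);
      }
      return result;
    }

Against the public API the same options are passed straight through; for the
gesture recognizer above (which sets roi_allowed=false), a rotation-only call
looks like this (recognizer and image construction elided):

    core::ImageProcessingOptions image_processing_options;
    image_processing_options.rotation_degrees = -90;  // 90° anti-clockwise.
    auto result = recognizer->Recognize(image, image_processing_options);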
diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index dfa77cb96..3d655cd50 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -59,6 +59,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index f3dcdd07d..8a32758f4 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -59,26 +60,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -// Returns a NormalizedRect covering the full image if input is not present. -// Otherwise, makes sure the x_center, y_center, width and height are set in -// case only a rotation was provided in the input. -NormalizedRect FillNormalizedRect( - std::optional normalized_rect) { - NormalizedRect result; - if (normalized_rect.has_value()) { - result = *normalized_rect; - } - bool has_coordinates = result.has_x_center() || result.has_y_center() || - result.has_width() || result.has_height(); - if (!has_coordinates) { - result.set_x_center(0.5); - result.set_y_center(0.5); - result.set_width(1); - result.set_height(1); - } - return result; -} - // Creates a MediaPipe graph config that contains a subgraph node of // type "ImageClassifierGraph". 
If the task is running in the live stream mode, // a "FlowLimiterCalculator" will be added to limit the number of frames in @@ -164,14 +145,16 @@ absl::StatusOr> ImageClassifier::Create( } absl::StatusOr ImageClassifier::Classify( - Image image, std::optional image_processing_options) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -183,14 +166,15 @@ absl::StatusOr ImageClassifier::Classify( absl::StatusOr ImageClassifier::ClassifyForVideo( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -206,14 +190,15 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( absl::Status ImageClassifier::ClassifyAsync( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 5dff06cc7..de69b7994 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -22,11 +22,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -109,12 +109,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. 
// and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -126,19 +124,17 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // YUVToImageCalculator is integrated. absl::StatusOr Classify( mediapipe::Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs image classification on the provided video frame. // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. // and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -150,7 +146,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Sends live image data to image classification, and the results will be @@ -158,12 +154,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. // and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -175,7 +169,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // sent to the object detector. The input timestamps must be monotonically // increasing. // - // The "result_callback" prvoides + // The "result_callback" provides: // - The classification results as a ClassificationResult object. // - The const reference to the corresponding input image that the image // classifier runs on. Note that the const reference to the image will no @@ -183,12 +177,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. 
absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); - // TODO: add Classify() variants taking a region of interest as - // additional argument. - // Shuts down the ImageClassifier when all works are done. absl::Status Close() { return runner_->Close(); } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 55830e520..0c45122c0 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -35,6 +34,8 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -49,9 +50,11 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::Classifications; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -547,12 +550,9 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // Crop around the soccer ball. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + // Region-of-interest around the soccer ball. + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -572,8 +572,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { ImageClassifier::Create(std::move(options))); // Specify a 90° anti-clockwise rotation. 
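  // (The replaced API expressed this in radians anti-clockwise, i.e.
  // set_rotation(M_PI / 2.0); ImageProcessingOptions counts degrees clockwise,
  // so the equivalent value below is rotation_degrees = -90.)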
- NormalizedRect image_processing_options; - image_processing_options.set_rotation(M_PI / 2.0); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -616,13 +616,10 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // Crop around the chair, with 90° anti-clockwise rotation. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.2821); - image_processing_options.set_y_center(0.2406); - image_processing_options.set_width(0.5642); - image_processing_options.set_height(0.1286); - image_processing_options.set_rotation(M_PI / 2.0); + // Region-of-interest around the chair, with 90° anti-clockwise rotation. + Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -633,7 +630,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { entries { categories { index: 560 - score: 0.6800408 + score: 0.6522213 category_name: "folding chair" } timestamp_ms: 0 @@ -643,6 +640,69 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { })pb")); } +// Testing all these once with ImageClassifier. +TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Invalid: left > right. + Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/0}; + auto results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: top > bottom. + roi = {/*left=*/0, /*top=*/0.9, /*right=*/1, /*bottom=*/0.1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: coordinates out of [0,1] range. 
+ roi = {/*left=*/-0.1, /*top=*/0, /*right=*/1, /*bottom=*/1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect values to be in [0,1]")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: rotation not a multiple of 90°. + image_processing_options = {/*region_of_interest=*/std::nullopt, + /*rotation_degrees=*/1}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected rotation to be a multiple of 90°")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { @@ -732,11 +792,9 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + // Region-of-interest around the soccer ball. + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN( @@ -877,11 +935,8 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. 
- NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK( diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 186909509..8220d8b7f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -75,6 +75,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index 9149a3cbe..dd19237ff 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" @@ -58,31 +59,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; -// Returns a NormalizedRect filling the whole image. If input is present, its -// rotation is set in the returned NormalizedRect and a check is performed to -// make sure no region-of-interest was provided. Otherwise, rotation is set to -// 0. -absl::StatusOr FillNormalizedRect( - std::optional normalized_rect) { - NormalizedRect result; - if (normalized_rect.has_value()) { - result = *normalized_rect; - } - bool has_coordinates = result.has_x_center() || result.has_y_center() || - result.has_width() || result.has_height(); - if (has_coordinates) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "ObjectDetector does not support region-of-interest.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - result.set_x_center(0.5); - result.set_y_center(0.5); - result.set_width(1); - result.set_height(1); - return result; -} - // Creates a MediaPipe graph config that contains a subgraph node of // "mediapipe.tasks.vision.ObjectDetectorGraph". 
If the task is running in the // live stream mode, a "FlowLimiterCalculator" will be added to limit the @@ -170,15 +146,16 @@ absl::StatusOr> ObjectDetector::Create( absl::StatusOr> ObjectDetector::Detect( mediapipe::Image image, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -189,15 +166,16 @@ absl::StatusOr> ObjectDetector::Detect( absl::StatusOr> ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -212,15 +190,16 @@ absl::StatusOr> ObjectDetector::DetectForVideo( absl::Status ObjectDetector::DetectAsync( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 2e5ed7b8d..44ce68ed9 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,9 +27,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -154,10 +154,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // after the yuv support is implemented. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). 
Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // For CPU images, the returned bounding boxes are expressed in the @@ -168,7 +167,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // images after enabling the gpu support in MediaPipe Tasks. absl::StatusOr> Detect( mediapipe::Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs object detection on the provided video frame. @@ -180,10 +179,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // For CPU images, the returned bounding boxes are expressed in the @@ -192,7 +190,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. absl::StatusOr> DetectForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Sends live image data to perform object detection, and the results will be @@ -206,10 +204,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // increasing. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // The "result_callback" provides @@ -223,7 +220,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Shuts down the ObjectDetector when all works are done. 
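For reference, a hedged usage sketch of the updated detector API (creation of
`object_detector` and `image` elided, as in the tests that follow):

    // Rotation-only options are accepted; attaching a region_of_interest
    // makes Detect() return kInvalidArgument, as the updated test below
    // verifies.
    ImageProcessingOptions image_processing_options;
    image_processing_options.rotation_degrees = -90;  // 90° anti-clockwise.
    auto results = object_detector->Detect(image, image_processing_options);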
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 8db3fa767..1747685dd 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -31,11 +31,12 @@ limitations under the License. #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/location_data.pb.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" @@ -64,6 +65,8 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -532,8 +535,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - NormalizedRect image_processing_options; - image_processing_options.set_rotation(M_PI / 2.0); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN( auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); @@ -557,16 +560,17 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.5); - image_processing_options.set_y_center(0.5); - image_processing_options.set_width(1.0); - image_processing_options.set_height(1.0); + Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("ObjectDetector does not support region-of-interest")); + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); } class VideoModeTest : public tflite_shims::testing::Test {}; From ea1d85d811f81cc10496bbef0a7f57703cc8e7b2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 16:41:07 -0700 Subject: [PATCH 27/55] Update model_task_graph to support multiple local model resources. 
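As an illustrative sketch (the two external-file names are hypothetical), a
task graph that bundles several models can now hold all of their resources at
once by disambiguating the cache tags via the new optional parameter:

    ASSIGN_OR_RETURN(const auto* detector_resources,
                     CreateModelResources(sc, std::move(detector_file),
                                          /*tag_suffix=*/"_detector"));
    ASSIGN_OR_RETURN(const auto* classifier_resources,
                     CreateModelResources(sc, std::move(classifier_file),
                                          /*tag_suffix=*/"_classifier"));

Both the suffixed graph-service tags and the graph-local fallback vectors
introduced below keep the two objects alive and distinct.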
PiperOrigin-RevId: 482917453 --- mediapipe/tasks/cc/core/model_task_graph.cc | 24 +++++++++++++-------- mediapipe/tasks/cc/core/model_task_graph.h | 23 ++++++++++++++------ 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 47334b673..66434483b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -156,21 +156,24 @@ absl::StatusOr ModelTaskGraph::GetConfig( } absl::StatusOr ModelTaskGraph::CreateModelResources( - SubgraphContext* sc, std::unique_ptr external_file) { + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); if (!model_resources_cache_service.IsAvailable()) { - ASSIGN_OR_RETURN(local_model_resources_, + ASSIGN_OR_RETURN(auto local_model_resource, ModelResources::Create("", std::move(external_file))); LOG(WARNING) << "A local ModelResources object is created. Please consider using " "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; - return local_model_resources_.get(); + local_model_resources_.push_back(std::move(local_model_resource)); + return local_model_resources_.back().get(); } ASSIGN_OR_RETURN( auto op_resolver_packet, model_resources_cache_service.GetObject().GetGraphOpResolverPacket()); - const std::string tag = CreateModelResourcesTag(sc->OriginalNode()); + const std::string tag = + absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix); ASSIGN_OR_RETURN(auto model_resources, ModelResources::Create(tag, std::move(external_file), op_resolver_packet)); @@ -182,7 +185,8 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( absl::StatusOr ModelTaskGraph::CreateModelAssetBundleResources( - SubgraphContext* sc, std::unique_ptr external_file) { + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset @@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources( // not owned by this model asset bundle resources. 
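  // (Presumably: with file_pointer_meta the underlying buffer is owned by the
  // caller, so caching the resources object in the shared graph service could
  // let it outlive that buffer; hence the graph-local copy kept below.)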
 if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
     ASSIGN_OR_RETURN(
-        local_model_asset_bundle_resources_,
+        auto local_model_asset_bundle_resource,
         ModelAssetBundleResources::Create("", std::move(external_file)));
     if (!has_file_pointer_meta) {
       LOG(WARNING)
@@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources(
           "ModelResourcesCacheService to cache the created ModelResources "
           "object in the CalculatorGraph.";
     }
-    return local_model_asset_bundle_resources_.get();
+    local_model_asset_bundle_resources_.push_back(
+        std::move(local_model_asset_bundle_resource));
+    return local_model_asset_bundle_resources_.back().get();
   }
-  const std::string tag =
-      CreateModelAssetBundleResourcesTag(sc->OriginalNode());
+  const std::string tag = absl::StrCat(
+      CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix);
   ASSIGN_OR_RETURN(
       auto model_bundle_resources,
       ModelAssetBundleResources::Create(tag, std::move(external_file)));
diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h
index 5ee70e8f3..50dcc903b 100644
--- a/mediapipe/tasks/cc/core/model_task_graph.h
+++ b/mediapipe/tasks/cc/core/model_task_graph.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include 
 #include 
 #include 
+#include 

 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
@@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph {
   // construction stage. Note that the external file contents will be moved
   // into the model resources object on creation. The returned model resources
   // pointer will provide graph authors with the access to the metadata
-  // extractor and the tflite model.
+  // extractor and the tflite model. When the model resources graph service is
+  // available, a tag is generated internally and associated with the created
+  // model resources object. If more than one model resources object is created
+  // in a graph, the model resources graph service adds the tag_suffix to
+  // support multiple resources.
   absl::StatusOr CreateModelResources(
-      SubgraphContext* sc, std::unique_ptr external_file);
+      SubgraphContext* sc, std::unique_ptr external_file,
+      const std::string tag_suffix = "");

   // If the model resources graph service is available, creates a model asset
   // bundle resources object from the subgraph context, and caches the created
@@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph {
   // that can only be used in the graph construction stage. Note that the
   // external file contents will be moved into the model asset bundle resources
   // object on creation. The returned model asset bundle resources pointer will
-  // provide graph authors with the access to extracted model files.
+  // provide graph authors with the access to extracted model files. When the
+  // model resources graph service is available, a tag is generated internally
+  // and associated with the created model asset bundle resources object. If
+  // more than one model asset bundle resources object is created in a graph,
+  // the model resources graph service adds the tag_suffix to support multiple
+  // resources.
   absl::StatusOr
 CreateModelAssetBundleResources(
-      SubgraphContext* sc, std::unique_ptr external_file);
+      SubgraphContext* sc, std::unique_ptr external_file,
+      const std::string tag_suffix = "");

   // Inserts a mediapipe task inference subgraph into the provided
   // GraphBuilder.
The returned node provides the following interfaces to the @@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph { api2::builder::Graph& graph) const; private: - std::unique_ptr local_model_resources_; + std::vector> local_model_resources_; - std::unique_ptr + std::vector> local_model_asset_bundle_resources_; }; From 7196db275efae1738bc31f18fb2ed366f1b41b1d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 17:26:48 -0700 Subject: [PATCH 28/55] Internal change PiperOrigin-RevId: 482925717 --- mediapipe/calculators/tensor/BUILD | 20 +++ mediapipe/tasks/cc/text/text_classifier/BUILD | 23 ++++ .../text_classifier/text_classifier_test.cc | 124 ------------------ mediapipe/tasks/cc/text/tokenizers/BUILD | 4 - mediapipe/tasks/testdata/text/BUILD | 7 +- 5 files changed, 47 insertions(+), 131 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index e953342da..99b5b3e91 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -253,6 +253,26 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "regex_preprocessor_calculator_test", + srcs = ["regex_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":regex_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:sink", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index a85538631..336b1bb45 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -63,6 +63,29 @@ cc_library( ], ) +cc_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_classifier", + ":text_classifier_test_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) + cc_library( name = "text_classifier_test_utils", srcs = ["text_classifier_test_utils.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 5b33f6606..62837be8c 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; -using ::testing::proto::Approximately; -using 
::testing::proto::IgnoringRepeatedFieldOrdering; -using ::testing::proto::Partially; constexpr float kEpsilon = 0.001; constexpr int kMaxSeqLen = 128; @@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { MP_ASSERT_OK(TextClassifier::Create(std::move(options))); } -TEST_F(TextClassifierTest, TextClassifierWithBert) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult negative_result, - classifier->Classify("unflinchingly bleak and desperate")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.956 } - categories { category_name: "positive" score: 0.044 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("it's a charming and often affecting journey")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.0 } - categories { category_name: "positive" score: 1.0 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result, - classifier->Classify("What a waste of my time.")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.813 } - categories { category_name: "Positive" score: 0.187 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("This is the best movie I’ve seen in recent years. 
" - "Strongly recommend it!")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.487 } - categories { category_name: "Positive" score: 0.513 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); - options->base_options.op_resolver = CreateCustomResolver(); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify("hello")); - ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( - classifications { - entries { - categories { index: 1 score: 1 } - categories { index: 0 score: 1 } - categories { index: 2 score: 0 } - } - } - )pb")))); -} - -TEST_F(TextClassifierTest, BertLongPositive) { - std::stringstream ss_for_positive_review; - ss_for_positive_review - << "it's a charming and often affecting journey and this is a long"; - for (int i = 0; i < kMaxSeqLen; ++i) { - ss_for_positive_review << " long"; - } - ss_for_positive_review << " movie review"; - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify(ss_for_positive_review.str())); - ASSERT_THAT(result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.014 } - categories { category_name: "positive" score: 0.986 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - } // namespace } // namespace text_classifier } // namespace text diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 048c7021d..e76d943c5 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -73,8 +73,6 @@ cc_library( ], ) -# TODO: This test fails in OSS - cc_library( name = "tokenizer_utils", srcs = ["tokenizer_utils.cc"], @@ -97,8 +95,6 @@ cc_library( ], ) -# TODO: This test fails in OSS - cc_library( name = "regex_tokenizer", srcs = [ diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 6cce5ae41..14999a03e 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -76,9 +76,10 @@ filegroup( filegroup( name = "text_classifier_models", - srcs = glob([ - "test_model_text_classifier*.tflite", - ]), + srcs = [ + "test_model_text_classifier_bool_output.tflite", + "test_model_text_classifier_with_regex_tokenizer.tflite", + ], ) filegroup( From abed54ea30f9d14bbdf2fb2da1544f685d2346e4 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 21 Oct 2022 18:15:22 -0700 Subject: [PATCH 29/55] Rename the mediapipe java image container from Image to MPImage. 
PiperOrigin-RevId: 482933122 --- .../framework/AndroidPacketCreator.java | 20 ++-- .../framework/image/BitmapExtractor.java | 14 +-- .../framework/image/BitmapImageBuilder.java | 12 +-- .../framework/image/BitmapImageContainer.java | 20 ++-- .../framework/image/ByteBufferExtractor.java | 97 ++++++++++--------- .../image/ByteBufferImageBuilder.java | 16 +-- .../image/ByteBufferImageContainer.java | 22 ++--- .../image/{Image.java => MPImage.java} | 58 +++++------ ...mageConsumer.java => MPImageConsumer.java} | 8 +- ...geContainer.java => MPImageContainer.java} | 4 +- ...mageProducer.java => MPImageProducer.java} | 8 +- ...Properties.java => MPImageProperties.java} | 36 +++---- .../framework/image/MediaImageBuilder.java | 15 +-- .../framework/image/MediaImageContainer.java | 31 +++--- .../framework/image/MediaImageExtractor.java | 19 ++-- .../examples/objectdetector/MainActivity.java | 7 +- .../ObjectDetectionResultImageView.java | 8 +- .../tasks/vision/core/BaseVisionTaskApi.java | 26 ++--- .../gesturerecognizer/GestureRecognizer.java | 26 ++--- .../imageclassifier/ImageClassifier.java | 36 +++---- .../vision/objectdetector/ObjectDetector.java | 25 ++--- .../GestureRecognizerTest.java | 10 +- .../imageclassifier/ImageClassifierTest.java | 12 +-- .../objectdetector/ObjectDetectorTest.java | 10 +- 24 files changed, 273 insertions(+), 267 deletions(-) rename mediapipe/java/com/google/mediapipe/framework/image/{Image.java => MPImage.java} (76%) rename mediapipe/java/com/google/mediapipe/framework/image/{ImageConsumer.java => MPImageConsumer.java} (87%) rename mediapipe/java/com/google/mediapipe/framework/image/{ImageContainer.java => MPImageContainer.java} (93%) rename mediapipe/java/com/google/mediapipe/framework/image/{ImageProducer.java => MPImageProducer.java} (75%) rename mediapipe/java/com/google/mediapipe/framework/image/{ImageProperties.java => MPImageProperties.java} (63%) diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 4af9dae78..e3a878f91 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -17,8 +17,8 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; -import com.google.mediapipe.framework.image.Image; -import com.google.mediapipe.framework.image.ImageProperties; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.framework.image.MPImageProperties; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator { } /** - * Creates an Image packet from an {@link Image}. + * Creates a MediaPipe Image packet from a {@link MPImage}. * *

The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. */ - public Packet createImage(Image image) { + public Packet createImage(MPImage image) { // TODO: Choose the best storage from multiple containers. - ImageProperties properties = image.getContainedImageProperties().get(0); - if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { + MPImageProperties properties = image.getContainedImageProperties().get(0); + if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) { ByteBuffer buffer = ByteBufferExtractor.extract(image); int numChannels = 0; switch (properties.getImageFormat()) { - case Image.IMAGE_FORMAT_RGBA: + case MPImage.IMAGE_FORMAT_RGBA: numChannels = 4; break; - case Image.IMAGE_FORMAT_RGB: + case MPImage.IMAGE_FORMAT_RGB: numChannels = 3; break; - case Image.IMAGE_FORMAT_ALPHA: + case MPImage.IMAGE_FORMAT_ALPHA: numChannels = 1; break; default: // fall out @@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator { int height = image.getHeight(); return createImage(buffer, width, height, numChannels); } - if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { + if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) { Bitmap bitmap = BitmapExtractor.extract(image); if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java index 4c6cebd4d..d6f50bf30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java @@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image; import android.graphics.Bitmap; /** - * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. + * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}. * - *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise + *

Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise * {@link IllegalArgumentException} will be thrown. */ public final class BitmapExtractor { /** - * Extracts a {@link android.graphics.Bitmap} from an {@link Image}. + * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}. * * @param image the image to extract {@link android.graphics.Bitmap} from. - * @return the {@link android.graphics.Bitmap} stored in {@link Image} + * @return the {@link android.graphics.Bitmap} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - public static Bitmap extract(Image image) { - ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); + public static Bitmap extract(MPImage image) { + MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP); if (imageContainer != null) { return ((BitmapImageContainer) imageContainer).getBitmap(); } else { // TODO: Support ByteBuffer -> Bitmap conversion. throw new IllegalArgumentException( - "Extracting Bitmap from an Image created by objects other than Bitmap is not" + "Extracting Bitmap from a MPImage created by objects other than Bitmap is not" + " supported"); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java index ea2ca6b1f..988cdf542 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java @@ -22,7 +22,7 @@ import android.provider.MediaStore; import java.io.IOException; /** - * Builds {@link Image} from {@link android.graphics.Bitmap}. + * Builds {@link MPImage} from {@link android.graphics.Bitmap}. * *

You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content @@ -49,7 +49,7 @@ public class BitmapImageBuilder { } /** - * Creates the builder to build {@link Image} from a file. + * Creates the builder to build {@link MPImage} from a file. * * @param context the application context. * @param uri the path to the resource file. @@ -58,15 +58,15 @@ public class BitmapImageBuilder { this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ BitmapImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java index 0457e1e9b..6fbcac214 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java @@ -16,19 +16,19 @@ limitations under the License. package com.google.mediapipe.framework.image; import android.graphics.Bitmap; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; -class BitmapImageContainer implements ImageContainer { +class BitmapImageContainer implements MPImageContainer { private final Bitmap bitmap; - private final ImageProperties properties; + private final MPImageProperties properties; public BitmapImageContainer(Bitmap bitmap) { this.bitmap = bitmap; this.properties = - ImageProperties.builder() + MPImageProperties.builder() .setImageFormat(convertFormatCode(bitmap.getConfig())) - .setStorageType(Image.STORAGE_TYPE_BITMAP) + .setStorageType(MPImage.STORAGE_TYPE_BITMAP) .build(); } @@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer { bitmap.recycle(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(Bitmap.Config config) { switch (config) { case ALPHA_8: - return Image.IMAGE_FORMAT_ALPHA; + return MPImage.IMAGE_FORMAT_ALPHA; case ARGB_8888: - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index a0e8c3dff..748a10667 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import com.google.auto.value.AutoValue; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import 
com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Locale; /** - * Utility for extracting {@link ByteBuffer} from {@link Image}. + * Utility for extracting {@link ByteBuffer} from {@link MPImage}. * - *
<p>
Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise - * {@link IllegalArgumentException} will be thrown. + *
<p>
Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER}, + * otherwise {@link IllegalArgumentException} will be thrown. */ public class ByteBufferExtractor { /** - * Extracts a {@link ByteBuffer} from an {@link Image}. + * Extracts a {@link ByteBuffer} from a {@link MPImage}. * *
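A sketch of the supported path, using the ByteBufferImageBuilder that is renamed later in this patch (the wrapper class and sizing are illustrative):

    import com.google.mediapipe.framework.image.ByteBufferExtractor;
    import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;
    import java.nio.ByteBuffer;

    final class ByteBufferExtractExample {
      static ByteBuffer readOnlyView(int width, int height) {
        // IMAGE_FORMAT_RGBA is 4 bytes per pixel.
        ByteBuffer pixels = ByteBuffer.allocateDirect(width * height * 4);
        MPImage image =
            new ByteBufferImageBuilder(pixels, width, height, MPImage.IMAGE_FORMAT_RGBA).build();
        // Returns a read-only view; throws IllegalArgumentException for images
        // whose storage type is not STORAGE_TYPE_BYTEBUFFER.
        return ByteBufferExtractor.extract(image);
      }
    }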
<p>
The returned {@link ByteBuffer} is a read-only view, with the first available {@link - * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. + * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}. * - * @see Image#getContainedImageProperties() + * @see MPImage#getContainedImageProperties() * @return A read-only {@link ByteBuffer}. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. */ @SuppressLint("SwitchIntDef") - public static ByteBuffer extract(Image image) { - ImageContainer container = image.getContainer(); + public static ByteBuffer extract(MPImage image) { + MPImageContainer container = image.getContainer(); switch (container.getImageProperties().getStorageType()) { - case Image.STORAGE_TYPE_BYTEBUFFER: + case MPImage.STORAGE_TYPE_BYTEBUFFER: ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); default: throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" + "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not" + " supported"); } } /** - * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}. * *
<p>
Format conversion spec: * @@ -70,26 +70,26 @@ public class ByteBufferExtractor { * * @param image the image to extract buffer from. * @param targetFormat the image format of the result bytebuffer. - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { - ImageContainer container; - ImageProperties byteBufferProperties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { + MPImageContainer container; + MPImageProperties byteBufferProperties = + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(targetFormat) .build(); if ((container = image.getContainer(byteBufferProperties)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); + @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) .asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ByteBuffer byteBuffer = extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) @@ -98,85 +98,89 @@ public class ByteBufferExtractor { return byteBuffer; } else { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by objects other than Bitmap or" + "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or" + " Bytebuffer is not supported"); } } - /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ + /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */ @AutoValue abstract static class Result { - /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ + /** + * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ public abstract ByteBuffer buffer(); - /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ - @ImageFormat + /** + * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ + @MPImageFormat public abstract int format(); - static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { + static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) { return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); } } /** - * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}. * *
<p>
It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * given {@code imageFormat} */ - static Result extractInRecommendedFormat(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + static Result extractInRecommendedFormat(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); - @ImageFormat int format = adviseImageFormat(bitmap); + @MPImageFormat int format = adviseImageFormat(bitmap); Result result = Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); boolean unused = image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); return result; - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return Result.create( byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getImageFormat()); } else { throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" + "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer" + " is not supported"); } } - @ImageFormat + @MPImageFormat private static int adviseImageFormat(Bitmap bitmap) { if (bitmap.getConfig() == Config.ARGB_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else { throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" + "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not" + " supported", bitmap.getConfig())); } } private static ByteBuffer extractByteBufferFromBitmap( - Bitmap bitmap, @ImageFormat int imageFormat) { + Bitmap bitmap, @MPImageFormat int imageFormat) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" + "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not" + " supported"); } if (bitmap.getConfig() == Config.ARGB_8888) { - if (imageFormat == Image.IMAGE_FORMAT_RGBA) { + if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); bitmap.copyPixelsToBuffer(buffer); buffer.rewind(); return buffer; - } else if (imageFormat == Image.IMAGE_FORMAT_RGB) { + } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) { // TODO: Try Use RGBA buffer to create RGB buffer which might be faster. 
int w = bitmap.getWidth(); int h = bitmap.getHeight(); @@ -196,14 +200,14 @@ public class ByteBufferExtractor { } throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" + "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format" + " %d is not supported", bitmap.getConfig(), imageFormat)); } private static ByteBuffer convertByteBuffer( - ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { - if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { + ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) { + if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); // Extend the buffer when the target is longer than the source. Use two cursors and sweep the // array reversely to convert in-place. @@ -221,7 +225,8 @@ public class ByteBufferExtractor { target.put(array, 0, target.capacity()); target.rewind(); return target; - } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { + } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA + && targetFormat == MPImage.IMAGE_FORMAT_RGB) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // array to convert in-place. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java index 07871da38..a650e4c33 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java @@ -15,11 +15,11 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; /** - * Builds a {@link Image} from a {@link ByteBuffer}. + * Builds a {@link MPImage} from a {@link ByteBuffer}. * *
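The RGB-to-RGBA branch of the convertByteBuffer logic shown above grows each 3-byte pixel to 4 bytes in place by sweeping the array from the last pixel down, so a pixel's RGB bytes are read before the writes for lower-index pixels can overwrite them. A standalone sketch of the same idea (class and method names hypothetical, not part of the patch):

    import java.nio.ByteBuffer;

    final class RgbToRgbaSketch {
      static ByteBuffer rgbToRgba(ByteBuffer source) {
        int numPixels = source.capacity() / 3;
        ByteBuffer target = ByteBuffer.allocateDirect(numPixels * 4);
        byte[] array = new byte[target.capacity()];
        source.rewind();
        source.get(array, 0, source.capacity()); // RGB data occupies the front of the array
        // Sweep backwards: pixel i's writes land at i*4..i*4+3, at or beyond its
        // own unread RGB bytes at i*3..i*3+2, and later (lower-index) pixels only
        // write below that.
        for (int i = numPixels - 1; i >= 0; i--) {
          array[i * 4 + 3] = (byte) 0xff; // opaque alpha
          array[i * 4 + 2] = array[i * 3 + 2];
          array[i * 4 + 1] = array[i * 3 + 1];
          array[i * 4] = array[i * 3];
        }
        target.put(array);
        target.rewind();
        return target;
      }
    }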
<p>
You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. @@ -32,7 +32,7 @@ public class ByteBufferImageBuilder { private final ByteBuffer buffer; private final int width; private final int height; - @ImageFormat private final int imageFormat; + @MPImageFormat private final int imageFormat; // Optional fields. private long timestamp; @@ -49,7 +49,7 @@ public class ByteBufferImageBuilder { * @param imageFormat how the data encode the image. */ public ByteBufferImageBuilder( - ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { + ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) { this.buffer = byteBuffer; this.width = width; this.height = height; @@ -58,14 +58,14 @@ public class ByteBufferImageBuilder { this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ ByteBufferImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java index 1c24c1dfd..82dbe32ca 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java @@ -15,21 +15,19 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; -class ByteBufferImageContainer implements ImageContainer { +class ByteBufferImageContainer implements MPImageContainer { private final ByteBuffer buffer; - private final ImageProperties properties; + private final MPImageProperties properties; - public ByteBufferImageContainer( - ByteBuffer buffer, - @ImageFormat int imageFormat) { + public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) { this.buffer = buffer; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(imageFormat) .build(); } @@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } - /** - * Returns the image format. - */ - @ImageFormat + /** Returns the image format. 
*/ + @MPImageFormat public int getImageFormat() { return properties.getImageFormat(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/Image.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java similarity index 76% rename from mediapipe/java/com/google/mediapipe/framework/image/Image.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index 49e63bcc0..e17cc4d30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/Image.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -29,10 +29,10 @@ import java.util.Map.Entry; /** * The wrapper class for image objects. * - *
<p>
{@link Image} is designed to be an immutable image container, which could be shared + *
<p>
{@link MPImage} is designed to be an immutable image container, which could be shared * cross-platforms. * - *
<p>
To construct an {@link Image}, use the provided builders: + *
<p>
To construct a {@link MPImage}, use the provided builders: * *
<ul>
    *
  • {@link ByteBufferImageBuilder} @@ -40,7 +40,7 @@ import java.util.Map.Entry; *
  • {@link MediaImageBuilder} *
* - *
<p>
{@link Image} uses reference counting to maintain internal storage. When it is created the + *
<p>
{@link MPImage} uses reference counting to maintain internal storage. When it is created the * reference count is 1. Developer can call {@link #close()} to reduce reference count to release * internal storage earlier, otherwise Java garbage collection will release the storage eventually. * @@ -53,7 +53,7 @@ import java.util.Map.Entry; *
  • {@link MediaImageExtractor} * */ -public class Image implements Closeable { +public class MPImage implements Closeable { /** Specifies the image format of an image. */ @IntDef({ @@ -69,7 +69,7 @@ public class Image implements Closeable { IMAGE_FORMAT_JPEG, }) @Retention(RetentionPolicy.SOURCE) - public @interface ImageFormat {} + public @interface MPImageFormat {} public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_RGBA = 1; @@ -98,14 +98,14 @@ public class Image implements Closeable { public static final int STORAGE_TYPE_IMAGE_PROXY = 4; /** - * Returns a list of supported image properties for this {@link Image}. + * Returns a list of supported image properties for this {@link MPImage}. * - *
<p>
    Currently {@link Image} only support single storage type so the size of return list will + *
<p>
    Currently {@link MPImage} only support single storage type so the size of return list will * always be 1. * - * @see ImageProperties + * @see MPImageProperties */ - public List getContainedImageProperties() { + public List getContainedImageProperties() { return Collections.singletonList(getContainer().getImageProperties()); } @@ -124,7 +124,7 @@ public class Image implements Closeable { return height; } - /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ + /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */ private synchronized void acquire() { referenceCount += 1; } @@ -132,7 +132,7 @@ public class Image implements Closeable { /** * Removes a reference that was previously acquired or init. * - *
<p>
    When {@link Image} is created, it has 1 reference count. + *
<p>
    When {@link MPImage} is created, it has 1 reference count. * *
<p>
    When the reference count becomes 0, it will release the resource under the hood. */ @@ -141,24 +141,24 @@ public class Image implements Closeable { public synchronized void close() { referenceCount -= 1; if (referenceCount == 0) { - for (ImageContainer imageContainer : containerMap.values()) { + for (MPImageContainer imageContainer : containerMap.values()) { imageContainer.close(); } } } - /** Advanced API access for {@link Image}. */ + /** Advanced API access for {@link MPImage}. */ static final class Internal { /** - * Acquires a reference on this {@link Image}. This will increase the reference count by 1. + * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. * *
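As a usage sketch of the lifecycle just described (try-with-resources works because MPImage implements Closeable; the surrounding class, method, and variable names are illustrative):

    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;

    final class LifecycleExample {
      static void useAndRelease(Bitmap bitmap) {
        // build() returns an MPImage with a reference count of 1.
        try (MPImage image = new BitmapImageBuilder(bitmap).build()) {
          // ... hand the image to extractors or task APIs ...
        }
        // close() dropped the count to 0 and closed every container;
        // for a Bitmap container this recycles the Bitmap.
      }
    }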
<p>
    This method is more useful for image consumer to acquire a reference so image resource * will not be closed accidentally. As image creator, normal developer doesn't need to call this * method. * - *
<p>
    The reference count is 1 when {@link Image} is created. Developer can call {@link - * #close()} to indicate it doesn't need this {@link Image} anymore. + *
<p>
    The reference count is 1 when {@link MPImage} is created. Developer can call {@link + * #close()} to indicate it doesn't need this {@link MPImage} anymore. * * @see #close() */ @@ -166,10 +166,10 @@ public class Image implements Closeable { image.acquire(); } - private final Image image; + private final MPImage image; - // Only Image creates the internal helper. - private Internal(Image image) { + // Only MPImage creates the internal helper. + private Internal(MPImage image) { this.image = image; } } @@ -179,15 +179,15 @@ public class Image implements Closeable { return new Internal(this); } - private final Map containerMap; + private final Map containerMap; private final long timestamp; private final int width; private final int height; private int referenceCount; - /** Constructs an {@link Image} with a built container. */ - Image(ImageContainer container, long timestamp, int width, int height) { + /** Constructs a {@link MPImage} with a built container. */ + MPImage(MPImageContainer container, long timestamp, int width, int height) { this.containerMap = new HashMap<>(); containerMap.put(container.getImageProperties(), container); this.timestamp = timestamp; @@ -201,10 +201,10 @@ public class Image implements Closeable { * * @return the current container. */ - ImageContainer getContainer() { + MPImageContainer getContainer() { // According to the design, in the future we will support multiple containers in one image. // Currently just return the original container. - // TODO: Cache multiple containers in Image. + // TODO: Cache multiple containers in MPImage. return containerMap.values().iterator().next(); } @@ -214,8 +214,8 @@ public class Image implements Closeable { *
<p>
    If there are multiple containers with required {@code storageType}, returns the first one. */ @Nullable - ImageContainer getContainer(@StorageType int storageType) { - for (Entry entry : containerMap.entrySet()) { + MPImageContainer getContainer(@StorageType int storageType) { + for (Entry entry : containerMap.entrySet()) { if (entry.getKey().getStorageType() == storageType) { return entry.getValue(); } @@ -225,13 +225,13 @@ public class Image implements Closeable { /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ @Nullable - ImageContainer getContainer(ImageProperties imageProperties) { + MPImageContainer getContainer(MPImageProperties imageProperties) { return containerMap.get(imageProperties); } /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ - boolean addContainer(ImageContainer container) { - ImageProperties imageProperties = container.getImageProperties(); + boolean addContainer(MPImageContainer container) { + MPImageProperties imageProperties = container.getImageProperties(); if (containerMap.containsKey(imageProperties)) { return false; } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java similarity index 87% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java index 18eed68c6..f9f343e93 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that can receive {@link Image} */ -public interface ImageConsumer { +/** Lightweight abstraction for an object that can receive {@link MPImage} */ +public interface MPImageConsumer { /** - * Called when an {@link Image} is available. + * Called when a {@link MPImage} is available. * *
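A sketch of a consumer that finishes its work synchronously, in line with the lifetime contract spelled out just below (the class name is illustrative; MPImageProducer is the companion interface renamed later in this section, and the width/height accessors are the MPImage getters shown above):

    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.framework.image.MPImageConsumer;

    final class SizeLoggingConsumer implements MPImageConsumer {
      @Override
      public void onNewMPImage(MPImage image) {
        // Use the image before returning; it is only guaranteed valid until then.
        android.util.Log.d("SizeLoggingConsumer", image.getWidth() + "x" + image.getHeight());
      }
    }

    // Wiring, given some MPImageProducer producer:
    // producer.setMPImageConsumer(new SizeLoggingConsumer());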
<p>
    The argument is only guaranteed to be available until this method returns. if you need to * extend its life time, acquire it, then release it when done. */ - void onNewImage(Image image); + void onNewMPImage(MPImage image); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java similarity index 93% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java index 727ec0893..674073b5b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java @@ -16,9 +16,9 @@ limitations under the License. package com.google.mediapipe.framework.image; /** Manages internal image data storage. The interface is package-private. */ -interface ImageContainer { +interface MPImageContainer { /** Returns the properties of the contained image. */ - ImageProperties getImageProperties(); + MPImageProperties getImageProperties(); /** Close the image container and releases the image resource inside. */ void close(); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java similarity index 75% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java index 4f3641d6f..9783935d4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that produce {@link Image} */ -public interface ImageProducer { +/** Lightweight abstraction for an object that produce {@link MPImage} */ +public interface MPImageProducer { - /** Sets the consumer that receives the {@link Image}. */ - void setImageConsumer(ImageConsumer imageConsumer); + /** Sets the consumer that receives the {@link MPImage}. */ + void setMPImageConsumer(MPImageConsumer imageConsumer); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java similarity index 63% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java index e33b33e7f..6005ce77b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java @@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image; import com.google.auto.value.AutoValue; import com.google.auto.value.extension.memoized.Memoized; -import com.google.mediapipe.framework.image.Image.ImageFormat; -import com.google.mediapipe.framework.image.Image.StorageType; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; +import com.google.mediapipe.framework.image.MPImage.StorageType; /** Groups a set of properties to describe how an image is stored. 
*/ @AutoValue -public abstract class ImageProperties { +public abstract class MPImageProperties { /** * Gets the pixel format of the image. * - * @see Image.ImageFormat + * @see MPImage.MPImageFormat */ - @ImageFormat + @MPImageFormat public abstract int getImageFormat(); /** * Gets the storage type of the image. * - * @see Image.StorageType + * @see MPImage.StorageType */ @StorageType public abstract int getStorageType(); @@ -45,36 +45,36 @@ public abstract class ImageProperties { public abstract int hashCode(); /** - * Creates a builder of {@link ImageProperties}. + * Creates a builder of {@link MPImageProperties}. * - * @see ImageProperties.Builder + * @see MPImageProperties.Builder */ static Builder builder() { - return new AutoValue_ImageProperties.Builder(); + return new AutoValue_MPImageProperties.Builder(); } - /** Builds a {@link ImageProperties}. */ + /** Builds a {@link MPImageProperties}. */ @AutoValue.Builder abstract static class Builder { /** - * Sets the {@link Image.ImageFormat}. + * Sets the {@link MPImage.MPImageFormat}. * - * @see ImageProperties#getImageFormat + * @see MPImageProperties#getImageFormat */ - abstract Builder setImageFormat(@ImageFormat int value); + abstract Builder setImageFormat(@MPImageFormat int value); /** - * Sets the {@link Image.StorageType}. + * Sets the {@link MPImage.StorageType}. * - * @see ImageProperties#getStorageType + * @see MPImageProperties#getStorageType */ abstract Builder setStorageType(@StorageType int value); - /** Builds the {@link ImageProperties}. */ - abstract ImageProperties build(); + /** Builds the {@link MPImageProperties}. */ + abstract MPImageProperties build(); } // Hide the constructor. - ImageProperties() {} + MPImageProperties() {} } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java index e351a87fd..9e719715d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java @@ -15,11 +15,12 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Builds {@link Image} from {@link android.media.Image}. + * Builds {@link MPImage} from {@link android.media.Image}. * *
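A sketch of the camera-frame path (ImageReader is a standard Android source of android.media.Image; the wrapper class and method name are illustrative, and the extractor counterpart appears later in this patch):

    import android.media.Image;
    import android.media.ImageReader;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.framework.image.MediaImageBuilder;

    final class MediaImageExample {
      static MPImage wrapLatestFrame(ImageReader reader) {
        Image mediaImage = reader.acquireLatestImage();
        if (mediaImage == null) {
          return null; // no frame ready yet
        }
        // The media image is wrapped, not copied; closing the MPImage later
        // closes the underlying android.media.Image as well.
        return new MediaImageBuilder(mediaImage).build();
      }
    }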
<p>
    Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * content in it. @@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi; public class MediaImageBuilder { // Mandatory fields. - private final android.media.Image mediaImage; + private final Image mediaImage; // Optional fields. private long timestamp; @@ -40,20 +41,20 @@ public class MediaImageBuilder { * * @param mediaImage image data object. */ - public MediaImageBuilder(android.media.Image mediaImage) { + public MediaImageBuilder(Image mediaImage) { this.mediaImage = mediaImage; this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ MediaImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new MediaImageContainer(mediaImage), timestamp, mediaImage.getWidth(), diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java index 144b64def..864c76df2 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java @@ -15,33 +15,34 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; @RequiresApi(VERSION_CODES.KITKAT) -class MediaImageContainer implements ImageContainer { +class MediaImageContainer implements MPImageContainer { - private final android.media.Image mediaImage; - private final ImageProperties properties; + private final Image mediaImage; + private final MPImageProperties properties; - public MediaImageContainer(android.media.Image mediaImage) { + public MediaImageContainer(Image mediaImage) { this.mediaImage = mediaImage; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE) .setImageFormat(convertFormatCode(mediaImage.getFormat())) .build(); } - public android.media.Image getImage() { + public Image getImage() { return mediaImage; } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer { mediaImage.close(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(int graphicsFormat) { // We only cover the format mentioned in // https://developer.android.com/reference/android/media/Image#getFormat() if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { - return Image.IMAGE_FORMAT_RGB; + return MPImage.IMAGE_FORMAT_RGB; } } switch (graphicsFormat) { case android.graphics.ImageFormat.JPEG: - return Image.IMAGE_FORMAT_JPEG; + return MPImage.IMAGE_FORMAT_JPEG; case 
android.graphics.ImageFormat.YUV_420_888: - return Image.IMAGE_FORMAT_YUV_420_888; + return MPImage.IMAGE_FORMAT_YUV_420_888; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java index 718cb471f..76bb5a5ec 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java @@ -15,13 +15,14 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Utility for extracting {@link android.media.Image} from {@link Image}. + * Utility for extracting {@link android.media.Image} from {@link MPImage}. * - *
<p>
    Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, + *
<p>
    Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE}, * otherwise {@link IllegalArgumentException} will be thrown. */ @RequiresApi(VERSION_CODES.KITKAT) @@ -30,20 +31,20 @@ public class MediaImageExtractor { private MediaImageExtractor() {} /** - * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for - * {@link Image} that built from {@link MediaImageBuilder}. + * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for + * {@link MPImage} that built from {@link MediaImageBuilder}. * * @param image the image to extract {@link android.media.Image} from. - * @return {@link android.media.Image} that stored in {@link Image}. + * @return {@link android.media.Image} that stored in {@link MPImage}. * @throws IllegalArgumentException if the extraction failed. */ - public static android.media.Image extract(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { + public static Image extract(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { return ((MediaImageContainer) container).getImage(); } throw new IllegalArgumentException( - "Extract Media Image from an Image created by objects other than Media Image" + "Extract Media Image from a MPImage created by objects other than Media Image" + " is not supported"); } } diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java index 7f7ec1389..11c8c1837 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -30,7 +30,7 @@ import androidx.activity.result.contract.ActivityResultContracts; import androidx.exifinterface.media.ExifInterface; // ContentResolver dependency import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; @@ -98,7 +98,7 @@ public class MainActivity extends AppCompatActivity { Log.e(TAG, "Bitmap rotation error:" + e); } if (bitmap != null) { - Image image = new BitmapImageBuilder(bitmap).build(); + MPImage image = new BitmapImageBuilder(bitmap).build(); ObjectDetectionResult detectionResult = objectDetector.detect(image); imageView.setData(image, detectionResult); runOnUiThread(() -> imageView.update()); @@ -144,7 +144,8 @@ public class MainActivity extends AppCompatActivity { MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); long frameIntervalMs = duration / numFrames; for (int i = 0; i < numFrames; ++i) { - Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); + MPImage image = + new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); ObjectDetectionResult detectionResult = objectDetector.detectForVideo(image, frameIntervalMs * i); // Currently only annotates the detection result on the first video 
frame and diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java index 94a4a90dc..283e48857 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java @@ -22,7 +22,7 @@ import android.graphics.Matrix; import android.graphics.Paint; import androidx.appcompat.widget.AppCompatImageView; import com.google.mediapipe.framework.image.BitmapExtractor; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; @@ -40,12 +40,12 @@ public class ObjectDetectionResultImageView extends AppCompatImageView { } /** - * Sets an {@link Image} and an {@link ObjectDetectionResult} to render. + * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render. * - * @param image an {@link Image} object for annotation. + * @param image a {@link MPImage} object for annotation. * @param result an {@link ObjectDetectionResult} object that contains the detection result. */ - public void setData(Image image, ObjectDetectionResult result) { + public void setData(MPImage image, ObjectDetectionResult result) { if (image == null || result == null) { return; } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 7ab8e75a1..49dab408c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -19,7 +19,7 @@ import com.google.mediapipe.formats.proto.RectProto.NormalizedRect; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.ProtoUtil; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; @@ -77,11 +77,11 @@ public class BaseVisionTaskApi implements AutoCloseable { * A synchronous method to process single image inputs. The call blocks the current thread until a * failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect * input. */ - protected TaskResult processImageData(Image image) { + protected TaskResult processImageData(MPImage image) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -102,13 +102,13 @@ public class BaseVisionTaskApi implements AutoCloseable { * A synchronous method to process single image inputs. 
The call blocks the current thread until a * failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates * are expected to be specified as normalized values in [0,1]. * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized * rect. */ - protected TaskResult processImageData(Image image, RectF roi) { + protected TaskResult processImageData(MPImage image, RectF roi) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -132,12 +132,12 @@ public class BaseVisionTaskApi implements AutoCloseable { * A synchronous method to process continuous video frames. The call blocks the current thread * until a failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect * input. */ - protected TaskResult processVideoData(Image image, long timestampMs) { + protected TaskResult processVideoData(MPImage image, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -158,14 +158,14 @@ public class BaseVisionTaskApi implements AutoCloseable { * A synchronous method to process continuous video frames. The call blocks the current thread * until a failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates * are expected to be specified as normalized values in [0,1]. * @param timestampMs the corresponding timestamp of the input image in milliseconds. * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized * rect. */ - protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) { + protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -189,12 +189,12 @@ public class BaseVisionTaskApi implements AutoCloseable { * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect * input. 
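For the video mode specifically, timestamps must increase monotonically across calls (the out-of-order timestamp test later in this patch exercises this). A sketch using the ImageClassifier wrapper updated further down, assuming the options type nests inside ImageClassifier as elsewhere in the Tasks Java API; options construction is omitted and the surrounding class is illustrative:

    import android.content.Context;
    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier;
    import java.util.List;

    final class VideoClassificationExample {
      static void classifyFrames(
          Context context, ImageClassifier.ImageClassifierOptions options, List<Bitmap> frames) {
        ImageClassifier classifier = ImageClassifier.createFromOptions(context, options);
        long frameIntervalMs = 33; // ~30 fps
        for (int i = 0; i < frames.size(); i++) {
          MPImage frame = new BitmapImageBuilder(frames.get(i)).build();
          // Each timestamp is strictly greater than the previous one.
          classifier.classifyForVideo(frame, i * frameIntervalMs); // result handling omitted
        }
      }
    }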
*/ - protected void sendLiveStreamData(Image image, long timestampMs) { + protected void sendLiveStreamData(MPImage image, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -215,14 +215,14 @@ public class BaseVisionTaskApi implements AutoCloseable { * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates * are expected to be specified as normalized values in [0,1]. * @param timestampMs the corresponding timestamp of the input image in milliseconds. * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized * rect. */ - protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) { + protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 660645d9c..560508903 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -26,7 +26,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -59,7 +59,7 @@ import java.util.Optional; * Model Maker. See . * *
<ul>
      - *
    • Input image {@link Image} + *
    • Input image {@link MPImage} *
        *
      • The image that gesture recognition runs on. *
      @@ -151,9 +151,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { public static GestureRecognizer createFromOptions( Context context, GestureRecognizerOptions recognizerOptions) { // TODO: Consolidate OutputHandler and TaskRunner. - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public GestureRecognitionResult convertToTaskResult(List packets) { // If there is no hands detected in the image, just returns empty lists. @@ -178,7 +178,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -222,10 +222,10 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
    • {@link Bitmap.Config.ARGB_8888} *
    * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognize(Image inputImage) { + public GestureRecognitionResult recognize(MPImage inputImage) { // TODO: add proper support for rotations. return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF()); } @@ -243,11 +243,11 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) { + public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) { // TODO: add proper support for rotations. return (GestureRecognitionResult) processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); @@ -267,11 +267,11 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void recognizeAsync(Image inputImage, long inputTimestampMs) { + public void recognizeAsync(MPImage inputImage, long inputTimestampMs) { // TODO: add proper support for rotations. sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); } @@ -333,7 +333,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * recognizer is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener value); + ResultListener value); /** Sets an optional error listener. */ public abstract Builder setErrorListener(ErrorListener value); @@ -386,7 +386,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // TODO update gesture confidence options after score merging calculator is ready. abstract Optional minGestureConfidence(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index e8e263b71..e7d9e4ea1 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -25,7 +25,7 @@ import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -164,9 +164,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. */ public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public ImageClassificationResult convertToTaskResult(List packets) { try { @@ -182,7 +182,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -225,10 +225,10 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(Image inputImage) { + public ImageClassificationResult classify(MPImage inputImage) { return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); } @@ -242,12 +242,12 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} specifying the region of interest on which to perform * classification. Coordinates are expected to be specified as normalized values in [0,1]. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(Image inputImage, RectF roi) { + public ImageClassificationResult classify(MPImage inputImage, RectF roi) { return (ImageClassificationResult) processImageData(inputImage, roi); } @@ -264,11 +264,11 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) { + public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) { return (ImageClassificationResult) processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); } @@ -286,14 +286,14 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} specifying the region of interest on which to perform * classification. Coordinates are expected to be specified as normalized values in [0,1]. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ public ImageClassificationResult classifyForVideo( - Image inputImage, RectF roi, long inputTimestampMs) { + MPImage inputImage, RectF roi, long inputTimestampMs) { return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); } @@ -311,11 +311,11 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void classifyAsync(Image inputImage, long inputTimestampMs) { + public void classifyAsync(MPImage inputImage, long inputTimestampMs) { sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); } @@ -334,13 +334,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param roi a {@link RectF} specifying the region of interest on which to perform * classification. Coordinates are expected to be specified as normalized values in [0,1]. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) { + public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) { sendLiveStreamData(inputImage, roi, inputTimestampMs); } @@ -379,7 +379,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * the image classifier is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener resultListener); + ResultListener resultListener); /** Sets an optional {@link ErrorListener}. */ public abstract Builder setErrorListener(ErrorListener errorListener); @@ -416,7 +416,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract Optional classifierOptions(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index bfce62791..0f2e7b540 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -22,7 +22,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -162,9 +162,9 @@ public final class ObjectDetector extends BaseVisionTaskApi { public static ObjectDetector createFromOptions( Context context, ObjectDetectorOptions detectorOptions) { // TODO: Consolidate OutputHandler and TaskRunner. - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public ObjectDetectionResult convertToTaskResult(List packets) { return ObjectDetectionResult.create( @@ -174,7 +174,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -217,10 +217,10 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect(Image inputImage) { + public ObjectDetectionResult detect(MPImage inputImage) { return (ObjectDetectionResult) processImageData(inputImage); } @@ -237,11 +237,11 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { + public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) { return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); } @@ -259,11 +259,11 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputImage a MediaPipe {@link MPImage} object for processing. * @param inputTimestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void detectAsync(Image inputImage, long inputTimestampMs) { + public void detectAsync(MPImage inputImage, long inputTimestampMs) { sendLiveStreamData(inputImage, inputTimestampMs); } @@ -333,7 +333,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { * Sets the {@link ResultListener} to receive the detection results asynchronously when the * object detector is in the live stream mode. */ - public abstract Builder setResultListener(ResultListener value); + public abstract Builder setResultListener( + ResultListener value); /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); @@ -378,7 +379,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { abstract List categoryDenylist(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index efec02b2a..8beea96ac 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -25,7 +25,7 @@ import com.google.common.truth.Correspondence; import com.google.mediapipe.formats.proto.ClassificationProto; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; @@ -357,7 +357,7 @@ public class GestureRecognizerTest { @Test public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(THUMB_UP_IMAGE); + MPImage image = getImageFromAsset(THUMB_UP_IMAGE); GestureRecognitionResult expectedResult = getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); GestureRecognizerOptions options = @@ -391,7 +391,7 @@ public class GestureRecognizerTest { @Test public void recognize_successWithLiveSteamMode() throws Exception { - Image image = getImageFromAsset(THUMB_UP_IMAGE); + MPImage image = getImageFromAsset(THUMB_UP_IMAGE); GestureRecognitionResult expectedResult = getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); GestureRecognizerOptions options = @@ -420,7 +420,7 @@ public class GestureRecognizerTest { } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -487,7 +487,7 @@ public 
class GestureRecognizerTest { assertThat(expectedGesture.categoryName()).isEqualTo(expectedGesture.categoryName()); } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index e02e8ebe7..966e4ff4a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -24,7 +24,7 @@ import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -342,7 +342,7 @@ public class ImageClassifierTest { @Test public void classify_succeedsWithVideoMode() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -361,7 +361,7 @@ public class ImageClassifierTest { @Test public void classify_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -388,7 +388,7 @@ public class ImageClassifierTest { @Test public void classify_succeedsWithLiveStreamMode() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -411,7 +411,7 @@ public class ImageClassifierTest { } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -437,7 +437,7 @@ public class ImageClassifierTest { } } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(480); assertThat(inputImage.getHeight()).isEqualTo(325); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index cdec57d76..91ffa9273 
100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -24,7 +24,7 @@ import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.core.BaseOptions; @@ -370,7 +370,7 @@ public class ObjectDetectorTest { @Test public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + MPImage image = getImageFromAsset(CAT_AND_DOG_IMAGE); ObjectDetectorOptions options = ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) @@ -395,7 +395,7 @@ public class ObjectDetectorTest { @Test public void detect_successWithLiveSteamMode() throws Exception { - Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + MPImage image = getImageFromAsset(CAT_AND_DOG_IMAGE); ObjectDetectorOptions options = ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) @@ -416,7 +416,7 @@ public class ObjectDetectorTest { } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -448,7 +448,7 @@ public class ObjectDetectorTest { assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom); } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); From 404323f631048700412cb680f983e0ab646b55ee Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 18:31:47 -0700 Subject: [PATCH 30/55] Add Mediapipe Tasks Gesture Recognizer benchmarks PiperOrigin-RevId: 482935780 --- mediapipe/tasks/testdata/vision/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ffb4760d9..f899be8ef 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -67,6 +67,7 @@ mediapipe_files(srcs = [ exports_files( srcs = [ + "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite", "expected_left_down_hand_landmarks.prototxt", "expected_left_down_hand_rotated_landmarks.prototxt", "expected_left_up_hand_landmarks.prototxt", From d8006a2f87a3dc1bac49244683f68ad09ed841f6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 21:52:40 -0700 Subject: [PATCH 31/55] Use model bundle for gesture recognizer. 
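Before this change, a GestureRecognizer had to be configured with three separate model files, wired through the temporary base_options_for_hand_detector, base_options_for_hand_landmarker, and base_options_for_gesture_recognizer fields; after it, a single model asset bundle configures the whole pipeline. A minimal sketch of the resulting C++ call site, assuming the public API declared in gesture_recognizer.h and the "gesture_recognizer.task" test asset exported later in this patch (an application would ship its own bundle):

    #include <memory>
    #include <utility>

    #include "mediapipe/tasks/cc/vision/core/running_mode.h"
    #include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h"

    // Sketch only: assumes this runs inside a function returning absl::Status.
    using ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizer;
    using ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerOptions;

    auto options = std::make_unique<GestureRecognizerOptions>();
    // One bundle now stands in for the three per-subgraph model files.
    options->base_options.model_asset_path = "gesture_recognizer.task";
    options->running_mode = ::mediapipe::tasks::vision::core::RunningMode::IMAGE;
    options->num_hands = 2;
    ASSIGN_OR_RETURN(std::unique_ptr<GestureRecognizer> recognizer,
                     GestureRecognizer::Create(std::move(options)));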
PiperOrigin-RevId: 482960305 --- .../tasks/cc/vision/gesture_recognizer/BUILD | 11 ++ .../gesture_recognizer/gesture_recognizer.cc | 46 ++---- .../gesture_recognizer/gesture_recognizer.h | 6 - .../gesture_recognizer_graph.cc | 71 +++++++++ .../hand_gesture_recognizer_graph.cc | 150 +++++++++++++++--- .../cc/vision/gesture_recognizer/proto/BUILD | 1 - ...and_gesture_recognizer_graph_options.proto | 7 +- .../hand_landmarker/hand_landmarker_graph.cc | 20 ++- .../hand_landmarker_graph_test.cc | 2 +- .../com/google/mediapipe/tasks/vision/BUILD | 1 + .../gesturerecognizer/GestureRecognizer.java | 53 ++----- .../GestureRecognizerTest.java | 141 +++++----------- mediapipe/tasks/testdata/vision/BUILD | 5 +- ...and_landmark.task => hand_landmarker.task} | Bin 7819037 -> 7819037 bytes third_party/external_files.bzl | 24 +-- 15 files changed, 309 insertions(+), 229 deletions(-) rename mediapipe/tasks/testdata/vision/{hand_landmark.task => hand_landmarker.task} (99%) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index a766c6b3f..6296017d4 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -62,13 +62,19 @@ cc_library( "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", @@ -93,10 +99,14 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", @@ -140,6 +150,7 @@ cc_library( 
"//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 000a2e141..d4ab16ac8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" @@ -112,57 +113,38 @@ CalculatorGraphConfig CreateGraphConfig( std::unique_ptr ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); - bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE; - - // TODO remove these workarounds for base options of subgraphs. // Configure hand detector options. - auto base_options_proto_for_hand_detector = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_hand_detector))); - base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode); auto* hand_detector_graph_options = options_proto->mutable_hand_landmarker_graph_options() ->mutable_hand_detector_graph_options(); - hand_detector_graph_options->mutable_base_options()->Swap( - base_options_proto_for_hand_detector.get()); hand_detector_graph_options->set_num_hands(options->num_hands); hand_detector_graph_options->set_min_detection_confidence( options->min_hand_detection_confidence); // Configure hand landmark detector options. 
- auto base_options_proto_for_hand_landmarker = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_hand_landmarker))); - base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode); - auto* hand_landmarks_detector_graph_options = - options_proto->mutable_hand_landmarker_graph_options() - ->mutable_hand_landmarks_detector_graph_options(); - hand_landmarks_detector_graph_options->mutable_base_options()->Swap( - base_options_proto_for_hand_landmarker.get()); - hand_landmarks_detector_graph_options->set_min_detection_confidence( - options->min_hand_presence_confidence); - auto* hand_landmarker_graph_options = options_proto->mutable_hand_landmarker_graph_options(); hand_landmarker_graph_options->set_min_tracking_confidence( options->min_tracking_confidence); + auto* hand_landmarks_detector_graph_options = + hand_landmarker_graph_options + ->mutable_hand_landmarks_detector_graph_options(); + hand_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_hand_presence_confidence); // Configure hand gesture recognizer options. - auto base_options_proto_for_gesture_recognizer = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_gesture_recognizer))); - base_options_proto_for_gesture_recognizer->set_use_stream_mode( - use_stream_mode); auto* hand_gesture_recognizer_graph_options = options_proto->mutable_hand_gesture_recognizer_graph_options(); - hand_gesture_recognizer_graph_options->mutable_base_options()->Swap( - base_options_proto_for_gesture_recognizer.get()); if (options->min_gesture_confidence >= 0) { - hand_gesture_recognizer_graph_options->mutable_classifier_options() + hand_gesture_recognizer_graph_options + ->mutable_canned_gesture_classifier_graph_options() + ->mutable_classifier_options() ->set_score_threshold(options->min_gesture_confidence); } return options_proto; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 29c8bea7b..3e281b26e 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -39,12 +39,6 @@ struct GestureRecognizerOptions { // model file with metadata, accelerator options, op resolver, etc. tasks::core::BaseOptions base_options; - // TODO: remove these. Temporary solutions before bundle asset is - // ready. - tasks::core::BaseOptions base_options_for_hand_landmarker; - tasks::core::BaseOptions base_options_for_hand_detector; - tasks::core::BaseOptions base_options_for_gesture_recognizer; - // The running mode of the task. Default to the image mode. // GestureRecognizer has three running modes: // 1) The image mode for recognizing hand gestures on single image inputs. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index e02eadde8..7ab4847dd 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -25,9 +25,13 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" @@ -46,6 +50,8 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: GestureRecognizerGraphOptions; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: @@ -61,6 +67,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; +constexpr char kHandGestureRecognizerBundleAssetName[] = + "hand_gesture_recognizer.task"; struct GestureRecognizerOutputs { Source> gesture; @@ -70,6 +79,53 @@ struct GestureRecognizerOutputs { Source image; }; +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + GestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto hand_landmarker_file, + resources.GetModelFile(kHandLandmarkerBundleAssetName)); + auto* hand_landmarker_graph_options = + options->mutable_hand_landmarker_graph_options(); + SetExternalFile(hand_landmarker_file, + hand_landmarker_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_landmarker_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN( + const auto hand_gesture_recognizer_file, + resources.GetModelFile(kHandGestureRecognizerBundleAssetName)); + auto* hand_gesture_recognizer_graph_options = + options->mutable_hand_gesture_recognizer_graph_options(); + SetExternalFile(hand_gesture_recognizer_file, + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + if (!hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_xnnpack() && + !hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_tflite()) { + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->mutable_xnnpack(); + LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. 
Sets " + << "HandGestureRecognizerGraph acceleartion to Xnnpack."; + } + hand_gesture_recognizer_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + } // namespace // A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs @@ -136,6 +192,21 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, BuildGestureRecognizerGraph( *sc->MutableOptions(), diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 4bbe94974..7b7746956 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -30,11 +30,17 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -51,6 +57,8 @@ using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::processors:: ConfigureTensorsToClassificationCalculator; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: HandGestureRecognizerGraphOptions; @@ -70,6 +78,14 @@ constexpr char kVectorTag[] = "VECTOR"; constexpr char kIndexTag[] = "INDEX"; constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; +constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite"; +constexpr char kCannedGestureClassifierTFLiteName[] = + "canned_gesture_classifier.tflite"; + +struct SubTaskModelResources { + const core::ModelResources* 
gesture_embedder_model_resource; + const core::ModelResources* canned_gesture_classifier_model_resource; +}; Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix, Graph& graph) { @@ -78,6 +94,41 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix, return node[Output<std::vector<Tensor>>{"TENSORS"}]; } +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + HandGestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto gesture_embedder_file, + resources.GetModelFile(kGestureEmbedderTFLiteName)); + auto* gesture_embedder_graph_options = + options->mutable_gesture_embedder_graph_options(); + SetExternalFile(gesture_embedder_file, + gesture_embedder_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + gesture_embedder_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file, + resources.GetModelFile(kCannedGestureClassifierTFLiteName)); + auto* canned_gesture_classifier_graph_options = + options->mutable_canned_gesture_classifier_graph_options(); + SetExternalFile( + canned_gesture_classifier_file, + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + canned_gesture_classifier_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + } // namespace // A @@ -128,27 +179,70 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr<CalculatorGraphConfig> GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN( - const auto* model_resources, - CreateModelResources<HandGestureRecognizerGraphOptions>(sc)); + if (sc->Options<HandGestureRecognizerGraphOptions>() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources<HandGestureRecognizerGraphOptions>( + sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead.
+ MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(const auto sub_task_model_resources, + CreateSubTaskModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN( - auto hand_gestures, - BuildGestureRecognizerGraph( - sc->Options(), *model_resources, - graph[Input(kHandednessTag)], - graph[Input(kLandmarksTag)], - graph[Input(kWorldLandmarksTag)], - graph[Input>(kImageSizeTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto hand_gestures, + BuildGestureRecognizerGraph( + sc->Options(), + sub_task_model_resources, + graph[Input(kHandednessTag)], + graph[Input(kLandmarksTag)], + graph[Input(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], graph)); hand_gestures >> graph[Output(kHandGesturesTag)]; return graph.GetConfig(); } private: + absl::StatusOr CreateSubTaskModelResources( + SubgraphContext* sc) { + auto* options = sc->MutableOptions(); + SubTaskModelResources sub_task_model_resources; + auto& gesture_embedder_model_asset = + *options->mutable_gesture_embedder_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.gesture_embedder_model_resource, + CreateModelResources(sc, + std::make_unique( + std::move(gesture_embedder_model_asset)), + "_gesture_embedder")); + auto& canned_gesture_classifier_model_asset = + *options->mutable_canned_gesture_classifier_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.canned_gesture_classifier_model_resource, + CreateModelResources( + sc, + std::make_unique( + std::move(canned_gesture_classifier_model_asset)), + "_canned_gesture_classifier")); + return sub_task_model_resources; + } + absl::StatusOr> BuildGestureRecognizerGraph( const HandGestureRecognizerGraphOptions& graph_options, - const core::ModelResources& model_resources, + const SubTaskModelResources& sub_task_model_resources, Source handedness, Source hand_landmarks, Source hand_world_landmarks, @@ -209,17 +303,33 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { auto concatenated_tensors = concatenate_tensor_vector.Out(""); // Inference for static hand gesture recognition. - // TODO add embedding step. 
- auto& inference = AddInference( - model_resources, graph_options.base_options().acceleration(), graph); - concatenated_tensors >> inference.In(kTensorsTag); - auto inference_output_tensors = inference.Out(kTensorsTag); + auto& gesture_embedder_inference = + AddInference(*sub_task_model_resources.gesture_embedder_model_resource, + graph_options.gesture_embedder_graph_options() + .base_options() + .acceleration(), + graph); + concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag); + auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag); + + auto& canned_gesture_classifier_inference = AddInference( + *sub_task_model_resources.canned_gesture_classifier_model_resource, + graph_options.canned_gesture_classifier_graph_options() + .base_options() + .acceleration(), + graph); + embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag); + auto inference_output_tensors = + canned_gesture_classifier_inference.Out(kTensorsTag); auto& tensors_to_classification = graph.AddNode("TensorsToClassificationCalculator"); MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( - graph_options.classifier_options(), - *model_resources.GetMetadataExtractor(), 0, + graph_options.canned_gesture_classifier_graph_options() + .classifier_options(), + *sub_task_model_resources.canned_gesture_classifier_model_resource + ->GetMetadataExtractor(), + 0, &tensors_to_classification.GetOptions< mediapipe::TensorsToClassificationCalculatorOptions>())); inference_output_tensors >> tensors_to_classification.In(kTensorsTag); diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 3b73bf2b0..0db47da7a 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -49,7 +49,6 @@ mediapipe_proto_library( ":gesture_embedder_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index a3281702a..7df2fed37 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,7 +18,6 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; @@ -37,15 +36,11 @@ message HandGestureRecognizerGraphOptions { // Options for GestureEmbedder. optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2; - // Options for GestureClassifier of default gestures. + // Options for GestureClassifier of canned gestures. optional GestureClassifierGraphOptions canned_gesture_classifier_graph_options = 3; // Options for GestureClassifier of custom gestures. 
optional GestureClassifierGraphOptions custom_gesture_classifier_graph_options = 4; - - // TODO: remove these. Temporary solutions before bundle asset is - // ready. - optional components.processors.proto.ClassifierOptions classifier_options = 5; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 3fbe38c1c..e610a412e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -92,18 +92,30 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, bool is_copy) { ASSIGN_OR_RETURN(const auto hand_detector_file, resources.GetModelFile(kHandDetectorTFLiteName)); + auto* hand_detector_graph_options = + options->mutable_hand_detector_graph_options(); SetExternalFile(hand_detector_file, - options->mutable_hand_detector_graph_options() - ->mutable_base_options() + hand_detector_graph_options->mutable_base_options() ->mutable_model_asset(), is_copy); + hand_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_detector_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + auto* hand_landmarks_detector_graph_options = + options->mutable_hand_landmarks_detector_graph_options(); SetExternalFile(hand_landmarks_detector_file, - options->mutable_hand_landmarks_detector_graph_options() - ->mutable_base_options() + hand_landmarks_detector_graph_options->mutable_base_options() ->mutable_model_asset(), is_copy); + hand_landmarks_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarks_detector_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); return absl::OkStatus(); } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index 08beb1a1b..f275486f5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -67,7 +67,7 @@ using ::testing::proto::Approximately; using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task"; +constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task"; constexpr char kLeftHandsImage[] = "left_hands.jpg"; constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index dcf3b3542..2bdcc2522 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -128,6 +128,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite", 
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 560508903..55cf275e9 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -38,6 +38,7 @@ import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.HandGestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto; @@ -300,13 +301,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi { */ public abstract Builder setRunningMode(RunningMode value); - // TODO: remove these. Temporary solutions before bundle asset is ready. - public abstract Builder setBaseOptionsHandDetector(BaseOptions value); - - public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value); - - public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value); - /** Sets the maximum number of hands can be detected by the GestureRecognizer. */ public abstract Builder setNumHands(Integer value); @@ -366,13 +360,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi { abstract BaseOptions baseOptions(); - // TODO: remove these. Temporary solutions before bundle asset is ready. - abstract BaseOptions baseOptionsHandDetector(); - - abstract BaseOptions baseOptionsHandLandmarker(); - - abstract BaseOptions baseOptionsGestureRecognizer(); - abstract RunningMode runningMode(); abstract Optional numHands(); @@ -405,22 +392,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi { */ @Override public CalculatorOptions convertToCalculatorOptionsProto() { - BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptions())); GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder = GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); // Setup HandDetectorGraphOptions. 
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder handDetectorGraphOptionsBuilder = - HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector()))); + HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder(); numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands); minHandDetectionConfidence() .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence); @@ -428,19 +411,12 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // Setup HandLandmarkerGraphOptions. HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptionsBuilder = - HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder(); minHandPresenceConfidence() .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder handLandmarkerGraphOptionsBuilder = - HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE)); + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder(); minTrackingConfidence() .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); handLandmarkerGraphOptionsBuilder @@ -450,16 +426,13 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // Setup HandGestureRecognizerGraphOptions. 
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder handGestureRecognizerGraphOptionsBuilder = - HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer()))); + HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder(); ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = ClassifierOptionsProto.ClassifierOptions.newBuilder(); minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); - handGestureRecognizerGraphOptionsBuilder.setClassifierOptions( - classifierOptionsBuilder.build()); + handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions( + GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder() + .setClassifierOptions(classifierOptionsBuilder.build())); taskOptionsBuilder .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index 8beea96ac..31e59a259 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -43,10 +43,7 @@ import org.junit.runners.Suite.SuiteClasses; @RunWith(Suite.class) @SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class}) public class GestureRecognizerTest { - private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite"; - private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite"; - private static final String GESTURE_RECOGNIZER_MODEL_FILE = - "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite"; + private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task"; private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; @@ -66,13 +63,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .build(); GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -88,13 +81,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - 
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .build(); GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -111,16 +100,12 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) // TODO update the confidence to be in range [0,1] after embedding model // and scoring calculator is integrated. - .setMinGestureConfidence(3.0f) + .setMinGestureConfidence(2.0f) .build(); GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -139,13 +124,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setNumHands(2) .build(); GestureRecognizer gestureRecognizer = @@ -168,19 +149,7 @@ public class GestureRecognizerTest { GestureRecognizerOptions.builder() .setBaseOptions( BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) - .build()) - .setBaseOptionsHandDetector( - BaseOptions.builder() - .setModelAssetPath(HAND_DETECTOR_MODEL_FILE) - .build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder() - .setModelAssetPath(HAND_LANDMARKER_MODEL_FILE) - .build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) .build()) .setRunningMode(mode) .setResultListener((gestureRecognitionResult, inputImage) -> {}) @@ -201,15 +170,7 @@ public class GestureRecognizerTest { GestureRecognizerOptions.builder() .setBaseOptions( BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) - .build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) .build()) 
.setRunningMode(RunningMode.LIVE_STREAM) .build()); @@ -223,13 +184,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.IMAGE) .build(); @@ -252,13 +209,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.VIDEO) .build(); @@ -281,13 +234,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener((gestureRecognitionResult, inputImage) -> {}) .build(); @@ -311,13 +260,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.IMAGE) .build(); @@ -335,13 +280,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + 
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
             .setRunningMode(RunningMode.VIDEO)
             .build();
     GestureRecognizer gestureRecognizer =
@@ -363,13 +304,9 @@ public class GestureRecognizerTest {
     GestureRecognizerOptions options =
         GestureRecognizerOptions.builder()
             .setBaseOptions(
-                BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
-            .setBaseOptionsHandDetector(
-                BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
-            .setBaseOptionsHandLandmarker(
-                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
-            .setBaseOptionsGestureRecognizer(
-                BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
             .setRunningMode(RunningMode.LIVE_STREAM)
             .setResultListener(
                 (actualResult, inputImage) -> {
@@ -397,13 +334,9 @@ public class GestureRecognizerTest {
     GestureRecognizerOptions options =
         GestureRecognizerOptions.builder()
             .setBaseOptions(
-                BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
-            .setBaseOptionsHandDetector(
-                BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
-            .setBaseOptionsHandLandmarker(
-                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
-            .setBaseOptionsGestureRecognizer(
-                BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
             .setRunningMode(RunningMode.LIVE_STREAM)
             .setResultListener(
                 (actualResult, inputImage) -> {
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index f899be8ef..ebb8f05a6 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -35,7 +35,6 @@ mediapipe_files(srcs = [
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
     "deeplabv3.tflite",
-    "hand_landmark.task",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "left_hands.jpg",
@@ -67,13 +66,13 @@ mediapipe_files(srcs = [
 exports_files(
     srcs = [
-        "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite",
         "expected_left_down_hand_landmarks.prototxt",
         "expected_left_down_hand_rotated_landmarks.prototxt",
         "expected_left_up_hand_landmarks.prototxt",
         "expected_left_up_hand_rotated_landmarks.prototxt",
         "expected_right_down_hand_landmarks.prototxt",
         "expected_right_up_hand_landmarks.prototxt",
+        "gesture_recognizer.task",
     ],
 )
 
@@ -119,9 +118,9 @@ filegroup(
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
         "deeplabv3.tflite",
-        "hand_landmark.task",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
+        "hand_landmarker.task",
         "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
         "mobilenet_v1_0.25_224_1_default_1.tflite",
         "mobilenet_v1_0.25_224_1_metadata_1.tflite",
diff --git a/mediapipe/tasks/testdata/vision/hand_landmark.task b/mediapipe/tasks/testdata/vision/hand_landmarker.task
similarity index 99%
rename from mediapipe/tasks/testdata/vision/hand_landmark.task
rename to mediapipe/tasks/testdata/vision/hand_landmarker.task
index b6eedf3248f860d736929631c3be3d6b0f9d4cdc..1ae9f7f6b33c52e16493b566cc76d014655c4d25 100644
GIT binary patch
delta 491
[491-byte base85-encoded binary delta omitted]

delta 491
[491-byte base85-encoded binary delta omitted]

Date: Sat, 22 Oct 2022 15:54:34 -0700
Subject: [PATCH 32/55] Internal change

PiperOrigin-RevId: 483078695
---
 .../calculators/tensor/inference_calculator_gl.cc |  9 +++++++--
 .../tensor/inference_calculator_gl_advanced.cc    | 15 ++++++++++-----
 mediapipe/framework/calculator_profile.proto      |  5 +++++
 mediapipe/framework/profiler/trace_buffer.h       |  5 +++++
 4 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc
index 1f3768ee0..bd8eb3eed 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc
@@ -26,6 +26,8 @@
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
 
+#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
+
 namespace mediapipe {
 namespace api2 {
 
@@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
     CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
     std::vector<Tensor>& output_tensors) {
   return gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
+      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
         // Explicitly copy input.
         for (int i = 0; i < input_tensors.size(); ++i) {
           glBindBuffer(GL_COPY_READ_BUFFER,
@@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
         }
 
         // Run inference.
-        RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+        {
+          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
+          RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+        }
 
         output_tensors.reserve(output_size_);
         for (int i = 0; i < output_size_; ++i) {
diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
index 7e11ee072..52359f7f5 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
@@ -32,6 +32,8 @@
 #include "mediapipe/util/android/file/base/helpers.h"
 #endif  // MEDIAPIPE_ANDROID
 
+#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
+
 namespace mediapipe {
 namespace api2 {
 
@@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
         const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
 
     absl::StatusOr<std::vector<Tensor>> Process(
-        const std::vector<Tensor>& input_tensors);
+        CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
 
     absl::Status Close();
 
@@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
 
 absl::StatusOr<std::vector<Tensor>>
 InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
-    const std::vector<Tensor>& input_tensors) {
+    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
   std::vector<Tensor> output_tensors;
   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
+      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
         for (int i = 0; i < input_tensors.size(); ++i) {
           MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
               input_tensors[i].GetOpenGlBufferReadView().name(), i));
@@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
               output_tensors.back().GetOpenGlBufferWriteView().name(), i));
         }
         // Run inference.
-        return tflite_gpu_runner_->Invoke();
+        {
+          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
+          return tflite_gpu_runner_->Invoke();
+        }
       }));
 
   return output_tensors;
@@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
   auto output_tensors = absl::make_unique<std::vector<Tensor>>();
   ASSIGN_OR_RETURN(*output_tensors,
-                   gpu_inference_runner_->Process(input_tensors));
+                   gpu_inference_runner_->Process(cc, input_tensors));
   kOutTensors(cc).Send(std::move(output_tensors));
 
   return absl::OkStatus();
diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto
index 06ec678a9..1512da6af 100644
--- a/mediapipe/framework/calculator_profile.proto
+++ b/mediapipe/framework/calculator_profile.proto
@@ -133,7 +133,12 @@ message GraphTrace {
     TPU_TASK = 13;
     GPU_CALIBRATION = 14;
     PACKET_QUEUED = 15;
+    GPU_TASK_INVOKE = 16;
+    TPU_TASK_INVOKE = 17;
   }
+  // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
+  // //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
+  // )
 
   // The timing for one packet set being processed at one calculator node.
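+  // Usage sketch (editorial illustration, not a line added by this change):
+  // a calculator makes one of these events appear in the trace by bracketing
+  // the timed work in a profiling scope, exactly as the two inference
+  // calculators above now do:
+  //
+  //   {
+  //     MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
+  //     RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+  //   }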
message CalculatorTrace { diff --git a/mediapipe/framework/profiler/trace_buffer.h b/mediapipe/framework/profiler/trace_buffer.h index 069f09610..60352c705 100644 --- a/mediapipe/framework/profiler/trace_buffer.h +++ b/mediapipe/framework/profiler/trace_buffer.h @@ -109,6 +109,11 @@ struct TraceEvent { static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; + static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; + static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; + // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags, + // //depot/mediapipe/framework/calculator_profile.proto:event_type, + // ) }; // Packet trace log buffer. From ec2a34d2a43639d8a9a2169f3377452a8a37b40f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 23 Oct 2022 16:42:49 -0700 Subject: [PATCH 33/55] Replace pytype_struct_contrib_test by py_strict_test. Also remove unnecessary BUILD attributes. PiperOrigin-RevId: 483237371 --- mediapipe/model_maker/python/core/BUILD | 2 +- mediapipe/model_maker/python/core/data/BUILD | 9 +-------- mediapipe/model_maker/python/core/tasks/BUILD | 7 +------ mediapipe/model_maker/python/core/utils/BUILD | 9 +-------- mediapipe/model_maker/python/vision/core/BUILD | 2 +- .../model_maker/python/vision/image_classifier/BUILD | 2 +- mediapipe/tasks/python/components/containers/BUILD | 2 +- mediapipe/tasks/python/core/BUILD | 2 +- mediapipe/tasks/python/metadata/metadata_writers/BUILD | 2 +- mediapipe/tasks/python/test/BUILD | 2 +- mediapipe/tasks/python/vision/BUILD | 2 +- mediapipe/tasks/python/vision/core/BUILD | 2 +- 12 files changed, 12 insertions(+), 31 deletions(-) diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD index 9f205bb11..636a1a720 100644 --- a/mediapipe/model_maker/python/core/BUILD +++ b/mediapipe/model_maker/python/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package( default_visibility = ["//mediapipe:__subpackages__"], diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD index c4c659d56..70a62e8f7 100644 --- a/mediapipe/model_maker/python/core/data/BUILD +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. 
licenses(["notice"]) @@ -23,15 +24,12 @@ package( py_library( name = "data_util", srcs = ["data_util.py"], - srcs_version = "PY3", ) py_test( name = "data_util_test", srcs = ["data_util_test.py"], data = ["//mediapipe/model_maker/python/core/data/testdata"], - python_version = "PY3", - srcs_version = "PY3", deps = [":data_util"], ) @@ -44,8 +42,6 @@ py_library( py_test( name = "dataset_test", srcs = ["dataset_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":dataset", "//mediapipe/model_maker/python/core/utils:test_util", @@ -55,14 +51,11 @@ py_test( py_library( name = "classification_dataset", srcs = ["classification_dataset.py"], - srcs_version = "PY3", deps = [":dataset"], ) py_test( name = "classification_dataset_test", srcs = ["classification_dataset_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [":classification_dataset"], ) diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD index b3588f0be..124de621a 100644 --- a/mediapipe/model_maker/python/core/tasks/BUILD +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. package( default_visibility = ["//mediapipe:__subpackages__"], @@ -23,7 +24,6 @@ licenses(["notice"]) py_library( name = "custom_model", srcs = ["custom_model.py"], - srcs_version = "PY3", deps = [ "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/utils:model_util", @@ -34,8 +34,6 @@ py_library( py_test( name = "custom_model_test", srcs = ["custom_model_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":custom_model", "//mediapipe/model_maker/python/core/utils:test_util", @@ -45,7 +43,6 @@ py_test( py_library( name = "classifier", srcs = ["classifier.py"], - srcs_version = "PY3", deps = [ ":custom_model", "//mediapipe/model_maker/python/core/data:dataset", @@ -55,8 +52,6 @@ py_library( py_test( name = "classifier_test", srcs = ["classifier_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":classifier", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 2538ec8fa..a2ec52044 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. 
licenses(["notice"]) @@ -24,7 +25,6 @@ py_library( name = "test_util", testonly = 1, srcs = ["test_util.py"], - srcs_version = "PY3", deps = [ ":model_util", "//mediapipe/model_maker/python/core/data:dataset", @@ -34,7 +34,6 @@ py_library( py_library( name = "model_util", srcs = ["model_util.py"], - srcs_version = "PY3", deps = [ ":quantization", "//mediapipe/model_maker/python/core/data:dataset", @@ -44,8 +43,6 @@ py_library( py_test( name = "model_util_test", srcs = ["model_util_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":model_util", ":quantization", @@ -62,8 +59,6 @@ py_library( py_test( name = "loss_functions_test", srcs = ["loss_functions_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [":loss_functions"], ) @@ -77,8 +72,6 @@ py_library( py_test( name = "quantization_test", srcs = ["quantization_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":quantization", ":test_util", diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD index 2658841ae..0b15a0276 100644 --- a/mediapipe/model_maker/python/vision/core/BUILD +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict test compatibility macro. licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 5b4ec2bd1..a2268059f 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python library rule. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python library rule. licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 8dd9fcd60..fd25401f7 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index d5cdeecda..76e2f4f4a 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. 
package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index 3e44218ac..2a0c29dec 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -1,4 +1,4 @@ -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package( default_visibility = [ diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 8e5b91cf9..92c5f4038 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 7ff818610..e7be51c8d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index c7422969a..df1b06f4c 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) From ab17be92947d539410c2dc1b111321de9562de04 Mon Sep 17 00:00:00 2001 From: Yuqi Li Date: Sun, 23 Oct 2022 23:03:44 -0700 Subject: [PATCH 34/55] Metadata Writer: Add Metadata Writer for image classifier. 
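
A minimal usage sketch of the new writer (illustrative only; the file names
below are placeholders, and the norm values match the unit test):

    from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

    with open("classifier.tflite", "rb") as f:
      model_buffer = bytearray(f.read())

    writer = image_classifier.MetadataWriter.create(
        model_buffer,
        input_norm_mean=[127.5],
        input_norm_std=[127.5],
        labels=metadata_writer.Labels().add_from_file("labels.txt"))
    tflite_with_metadata, metadata_json = writer.populate()

Score calibration is optional: ScoreCalibration.create_from_file reads a CSV
where each non-empty line holds scale,slope,offset[,min_score] for the
corresponding output index, and empty lines fall back to default_score.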
PiperOrigin-RevId: 483282627 --- .../python/metadata/metadata_writers/BUILD | 6 + .../metadata_writers/image_classifier.py | 71 ++ .../metadata_writers/metadata_writer.py | 130 ++- .../test/metadata/metadata_writers/BUILD | 21 +- .../metadata_writers/image_classifier_test.py | 76 ++ .../metadata_writers/metadata_writer_test.py | 56 +- mediapipe/tasks/testdata/metadata/BUILD | 10 + mediapipe/tasks/testdata/metadata/labels.txt | 1001 +++++++++++++++++ .../metadata/mobilenet_v2_1.0_224.json | 82 ++ .../metadata/mobilenet_v2_1.0_224_quant.json | 82 ++ third_party/external_files.bzl | 30 + 11 files changed, 1537 insertions(+), 28 deletions(-) create mode 100644 mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py create mode 100644 mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py create mode 100644 mediapipe/tasks/testdata/metadata/labels.txt create mode 100644 mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json create mode 100644 mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index 2a0c29dec..d2b55d47d 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -37,3 +37,9 @@ py_library( srcs = ["writer_utils.py"], deps = ["//mediapipe/tasks/metadata:schema_py"], ) + +py_library( + name = "image_classifier", + srcs = ["image_classifier.py"], + deps = [":metadata_writer"], +) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py b/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py new file mode 100644 index 000000000..c516a342d --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py @@ -0,0 +1,71 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the image classifier models.""" + +from typing import List, Optional + +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer + +_MODEL_NAME = "ImageClassifier" +_MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a " + "known set of categories.") + + +class MetadataWriter(metadata_writer.MetadataWriterBase): + """MetadataWriter to write the metadata for image classifier.""" + + @classmethod + def create( + cls, + model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + labels: metadata_writer.Labels, + score_calibration: Optional[metadata_writer.ScoreCalibration] = None + ) -> "MetadataWriter": + """Creates MetadataWriter to write the metadata for image classifier. + + The parameters required in this method are mandatory when using MediaPipe + Tasks. + + Note that only the output TFLite is used for deployment. 
The output JSON
+    content is used to interpret the metadata content.
+
+    Args:
+      model_buffer: A valid flatbuffer loaded from the TFLite model file.
+      input_norm_mean: the mean value used in the input tensor normalization
+        [1].
+      input_norm_std: the std value used in the input tensor normalization
+        [1].
+      labels: an instance of Labels helper class used in the output
+        classification tensor [2].
+      score_calibration: A container of the score calibration operation [3] in
+        the classification tensor. Optional if the model does not use score
+        calibration.
+
+    [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
+    [2]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
+    [3]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
+
+    Returns:
+      A MetadataWriter object.
+    """
+    writer = metadata_writer.MetadataWriter(model_buffer)
+    writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
+    writer.add_image_input(input_norm_mean, input_norm_std)
+    writer.add_classification_output(labels, score_calibration)
+    return cls(writer)
diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py
index e69efd015..5a2eaba07 100644
--- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py
+++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py
@@ -15,19 +15,22 @@
 """Generic metadata writer."""
 
 import collections
+import csv
 import dataclasses
 import os
 import tempfile
 from typing import List, Optional, Tuple
 
 import flatbuffers
-from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
-from mediapipe.tasks.python.metadata import metadata as _metadata
+from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
+from mediapipe.tasks.python.metadata import metadata
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
 from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
 
 _INPUT_IMAGE_NAME = 'image'
 _INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
+_OUTPUT_CLASSIFICATION_NAME = 'score'
+_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
 
 
 @dataclasses.dataclass
@@ -140,26 +143,85 @@ class Labels(object):
 class ScoreCalibration:
   """Simple container holding score calibration related parameters."""
 
-  # A shortcut to avoid client side code importing _metadata_fb
-  transformation_types = _metadata_fb.ScoreTransformationType
+  # A shortcut to avoid client side code importing metadata_fb
+  transformation_types = metadata_fb.ScoreTransformationType
 
   def __init__(self,
-               transformation_type: _metadata_fb.ScoreTransformationType,
-               parameters: List[CalibrationParameter],
+               transformation_type: metadata_fb.ScoreTransformationType,
+               parameters: List[Optional[CalibrationParameter]],
                default_score: int = 0):
     self.transformation_type = transformation_type
     self.parameters = parameters
     self.default_score = default_score
 
+  @classmethod
+  def create_from_file(cls,
+                       transformation_type: metadata_fb.ScoreTransformationType,
+                       file_path: str,
+                       default_score: int = 0) -> 'ScoreCalibration':
+    """Creates ScoreCalibration from the file.
+ + Args: + transformation_type: type of the function used for transforming the + uncalibrated score before applying score calibration. + file_path: file_path of the score calibration file [1]. Contains + sigmoid-based score calibration parameters, formatted as CSV. Lines + contain for each index of an output tensor the scale, slope, offset and + (optional) min_score parameters to be used for sigmoid fitting (in this + order and in `strtof`-compatible [2] format). Scale should be a + non-negative value. A line may be left empty to default calibrated + scores for this index to default_score. In summary, each line should + thus contain 0, 3 or 4 comma-separated values. + default_score: the default calibrated score to apply if the uncalibrated + score is below min_score or if no parameters were specified for a given + index. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133 + [2]: + https://en.cppreference.com/w/c/string/byte/strtof + + Returns: + A ScoreCalibration object. + Raises: + ValueError: if the score_calibration file is malformed. + """ + with open(file_path, 'r') as calibration_file: + csv_reader = csv.reader(calibration_file, delimiter=',') + parameters = [] + for row in csv_reader: + if not row: + parameters.append(None) + continue + + if len(row) != 3 and len(row) != 4: + raise ValueError( + f'Expected empty lines or 3 or 4 parameters per line in score' + f' calibration file, but got {len(row)}.') + + if float(row[0]) < 0: + raise ValueError( + f'Expected scale to be a non-negative value, but got ' + f'{float(row[0])}.') + + parameters.append( + CalibrationParameter( + scale=float(row[0]), + slope=float(row[1]), + offset=float(row[2]), + min_score=None if len(row) == 3 else float(row[3]))) + + return cls(transformation_type, parameters, default_score) + def _fill_default_tensor_names( - tensor_metadata: List[_metadata_fb.TensorMetadataT], + tensor_metadata_list: List[metadata_fb.TensorMetadataT], tensor_names_from_model: List[str]): """Fills the default tensor names.""" # If tensor name in metadata is empty, default to the tensor name saved in # the model. - for metadata, name in zip(tensor_metadata, tensor_names_from_model): - metadata.name = metadata.name or name + for tensor_metadata, name in zip(tensor_metadata_list, + tensor_names_from_model): + tensor_metadata.name = tensor_metadata.name or name def _pair_tensor_metadata( @@ -212,7 +274,7 @@ def _create_metadata_buffer( input_metadata = [m.create_metadata() for m in input_md] else: num_input_tensors = writer_utils.get_subgraph(model_buffer).InputsLength() - input_metadata = [_metadata_fb.TensorMetadataT()] * num_input_tensors + input_metadata = [metadata_fb.TensorMetadataT()] * num_input_tensors _fill_default_tensor_names(input_metadata, writer_utils.get_input_tensor_names(model_buffer)) @@ -224,12 +286,12 @@ def _create_metadata_buffer( output_metadata = [m.create_metadata() for m in output_md] else: num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength() - output_metadata = [_metadata_fb.TensorMetadataT()] * num_output_tensors + output_metadata = [metadata_fb.TensorMetadataT()] * num_output_tensors _fill_default_tensor_names(output_metadata, writer_utils.get_output_tensor_names(model_buffer)) # Create the subgraph metadata. 
- subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata = metadata_fb.SubGraphMetadataT() subgraph_metadata.inputTensorMetadata = input_metadata subgraph_metadata.outputTensorMetadata = output_metadata @@ -243,7 +305,7 @@ def _create_metadata_buffer( b = flatbuffers.Builder(0) b.Finish( model_metadata.Pack(b), - _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) return b.Output() @@ -291,7 +353,7 @@ class MetadataWriter(object): name=model_name, description=model_description) return self - color_space_types = _metadata_fb.ColorSpaceType + color_space_types = metadata_fb.ColorSpaceType def add_feature_input(self, name: Optional[str] = None, @@ -305,7 +367,7 @@ class MetadataWriter(object): self, norm_mean: List[float], norm_std: List[float], - color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.RGB, + color_space_type: Optional[int] = metadata_fb.ColorSpaceType.RGB, name: str = _INPUT_IMAGE_NAME, description: str = _INPUT_IMAGE_DESCRIPTION) -> 'MetadataWriter': """Adds an input image metadata for the image input. @@ -341,9 +403,6 @@ class MetadataWriter(object): self._input_mds.append(input_md) return self - _OUTPUT_CLASSIFICATION_NAME = 'score' - _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively' - def add_classification_output( self, labels: Optional[Labels] = None, @@ -416,8 +475,7 @@ class MetadataWriter(object): A tuple of (model_with_metadata_in_bytes, metdata_json_content) """ # Populates metadata and associated files into TFLite model buffer. - populator = _metadata.MetadataPopulator.with_model_buffer( - self._model_buffer) + populator = metadata.MetadataPopulator.with_model_buffer(self._model_buffer) metadata_buffer = _create_metadata_buffer( model_buffer=self._model_buffer, general_md=self._general_md, @@ -429,7 +487,7 @@ class MetadataWriter(object): populator.populate() tflite_content = populator.get_model_buffer() - displayer = _metadata.MetadataDisplayer.with_model_buffer(tflite_content) + displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) metadata_json_content = displayer.get_metadata_json() return tflite_content, metadata_json_content @@ -452,9 +510,7 @@ class MetadataWriter(object): """Stores calibration parameters in a csv file.""" filepath = os.path.join(self._temp_folder.name, filename) with open(filepath, 'w') as f: - for idx, item in enumerate(calibrations): - if idx != 0: - f.write('\n') + for item in calibrations: if item: if item.scale is None or item.slope is None or item.offset is None: raise ValueError('scale, slope and offset values can not be set to ' @@ -463,6 +519,30 @@ class MetadataWriter(object): f.write(f'{item.scale},{item.slope},{item.offset},{item.min_score}') else: f.write(f'{item.scale},{item.slope},{item.offset}') + f.write('\n') - self._associated_files.append(filepath) + self._associated_files.append(filepath) return filepath + + +class MetadataWriterBase: + """Base MetadataWriter class which contains the apis exposed to users. + + MetadataWriter for Tasks e.g. image classifier / object detector will inherit + this class for their own usage. + """ + + def __init__(self, writer: MetadataWriter) -> None: + self.writer = writer + + def populate(self) -> Tuple[bytearray, str]: + """Populates metadata into the TFLite file. + + Note that only the output tflite is used for deployment. The output JSON + content is used to interpret the metadata content. 
+
+    Returns:
+      A tuple of (model_with_metadata_in_bytes, metadata_json_content)
+    """
+    return self.writer.populate()
+
diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD
index 8779d2fb6..a7bfd297d 100644
--- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD
+++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD
@@ -28,9 +28,28 @@ py_test(
 py_test(
     name = "metadata_writer_test",
     srcs = ["metadata_writer_test.py"],
-    data = ["//mediapipe/tasks/testdata/metadata:model_files"],
+    data = [
+        "//mediapipe/tasks/testdata/metadata:data_files",
+        "//mediapipe/tasks/testdata/metadata:model_files",
+    ],
     deps = [
         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
+
+py_test(
+    name = "image_classifier_test",
+    srcs = ["image_classifier_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/metadata:data_files",
+        "//mediapipe/tasks/testdata/metadata:model_files",
+    ],
+    deps = [
+        "//mediapipe/tasks/metadata:metadata_schema_py",
+        "//mediapipe/tasks/python/metadata",
+        "//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
+        "//mediapipe/tasks/python/test:test_utils",
+    ],
+)
diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py
new file mode 100644
index 000000000..4bbd91667
--- /dev/null
+++ b/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py
@@ -0,0 +1,76 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for metadata_writer.image_classifier.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb +from mediapipe.tasks.python.metadata import metadata +from mediapipe.tasks.python.metadata.metadata_writers import image_classifier +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.test import test_utils + +_FLOAT_MODEL = test_utils.get_test_data_path( + "mobilenet_v2_1.0_224_without_metadata.tflite") +_QUANT_MODEL = test_utils.get_test_data_path( + "mobilenet_v2_1.0_224_quant_without_metadata.tflite") +_LABEL_FILE = test_utils.get_test_data_path("labels.txt") +_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt") +_SCORE_CALIBRATION_FILENAME = "score_calibration.txt" +_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2 +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json") +_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json") + + +class ImageClassifierTest(parameterized.TestCase): + + @parameterized.named_parameters( + { + "testcase_name": "test_float_model", + "model_file": _FLOAT_MODEL, + "golden_json": _FLOAT_JSON + }, { + "testcase_name": "test_quant_model", + "model_file": _QUANT_MODEL, + "golden_json": _QUANT_JSON + }) + def test_write_metadata(self, model_file: str, golden_json: str): + with open(model_file, "rb") as f: + model_buffer = f.read() + writer = image_classifier.MetadataWriter.create( + model_buffer, [_NORM_MEAN], [_NORM_STD], + labels=metadata_writer.Labels().add_from_file(_LABEL_FILE), + score_calibration=metadata_writer.ScoreCalibration.create_from_file( + metadata_fb.ScoreTransformationType.LOG, _SCORE_CALIBRATION_FILE, + _DEFAULT_SCORE_CALIBRATION_VALUE)) + tflite_content, metadata_json = writer.populate() + + with open(golden_json, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) + file_buffer = displayer.get_associated_file_buffer( + _SCORE_CALIBRATION_FILENAME) + with open(_SCORE_CALIBRATION_FILE, "rb") as f: + expected_file_buffer = f.read() + self.assertEqual(file_buffer, expected_file_buffer) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py index c39b4a555..51b043c7d 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py @@ -13,6 +13,9 @@ # limitations under the License. 
# ============================================================================== """Tests for metadata writer classes.""" +import os +import tempfile + from absl.testing import absltest from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer @@ -20,6 +23,7 @@ from mediapipe.tasks.python.test import test_utils _IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path( 'mobilenet_v1_0.25_224_1_default_1.tflite') +_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt') class LabelsTest(absltest.TestCase): @@ -49,6 +53,54 @@ class LabelsTest(absltest.TestCase): ]) +class ScoreCalibrationTest(absltest.TestCase): + + def test_create_from_file_successful(self): + score_calibration = metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + _SCORE_CALIBRATION_FILE) + self.assertLen(score_calibration.parameters, 511) + self.assertIsNone(score_calibration.parameters[0]) + self.assertEqual( + score_calibration.parameters[1], + metadata_writer.CalibrationParameter( + scale=0.9876328110694885, + slope=0.36622241139411926, + offset=0.5352765321731567, + min_score=0.71484375)) + self.assertEqual( + score_calibration.parameters[510], + metadata_writer.CalibrationParameter( + scale=0.9901729226112366, + slope=0.8561913371086121, + offset=0.8783953189849854, + min_score=0.5859375)) + + def test_create_from_file_fail(self): + with tempfile.TemporaryDirectory() as temp_dir: + test_file = os.path.join(temp_dir, 'score_calibration.csv') + with open(test_file, 'w') as f: + f.write('0.98,0.5\n') + + with self.assertRaisesRegex( + ValueError, + 'Expected empty lines or 3 or 4 parameters per line in score ' + 'calibration file, but got 2.' + ): + metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + test_file) + + with open(test_file, 'w') as f: + f.write('-0.98,0.5,0.34\n') + with self.assertRaisesRegex( + ValueError, + 'Expected scale to be a non-negative value, but got -0.98.'): + metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + test_file) + + class MetadataWriterForTaskTest(absltest.TestCase): def setUp(self): @@ -197,7 +249,7 @@ class MetadataWriterForTaskTest(absltest.TestCase): "output_tensor_metadata": [ { "name": "score", - "description": "Score of the labels respectively", + "description": "Score of the labels respectively.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { @@ -298,7 +350,7 @@ class MetadataWriterForTaskTest(absltest.TestCase): "output_tensor_metadata": [ { "name": "score", - "description": "Score of the labels respectively", + "description": "Score of the labels respectively.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 9f50368b8..6d7bbab6a 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -29,6 +29,8 @@ mediapipe_files(srcs = [ "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v2_1.0_224_quant.tflite", + "mobilenet_v2_1.0_224_quant_without_metadata.tflite", + "mobilenet_v2_1.0_224_without_metadata.tflite", ]) exports_files([ @@ -48,6 +50,9 @@ exports_files([ "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", + "labels.txt", + 
"mobilenet_v2_1.0_224.json", + "mobilenet_v2_1.0_224_quant.json", ]) filegroup( @@ -59,6 +64,8 @@ filegroup( "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v2_1.0_224_quant.tflite", + "mobilenet_v2_1.0_224_quant_without_metadata.tflite", + "mobilenet_v2_1.0_224_without_metadata.tflite", ], ) @@ -78,6 +85,9 @@ filegroup( "input_image_tensor_float_meta.json", "input_image_tensor_uint8_meta.json", "input_image_tensor_unsupported_meta.json", + "labels.txt", + "mobilenet_v2_1.0_224.json", + "mobilenet_v2_1.0_224_quant.json", "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", diff --git a/mediapipe/tasks/testdata/metadata/labels.txt b/mediapipe/tasks/testdata/metadata/labels.txt new file mode 100644 index 000000000..fe811239d --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch 
terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug 
+coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio 
couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json new file mode 100644 index 000000000..6f01f9f09 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json @@ -0,0 +1,82 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. 
The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json new file mode 100644 index 000000000..e2ba42e3b --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json @@ -0,0 +1,82 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. 
The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 54c0dde0a..84b354c99 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -364,6 +364,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/knift_labelmap.txt?generation=1661875792821628"], ) + http_file( + name = "com_google_mediapipe_labels_txt", + sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9", + urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"], + ) + http_file( name = "com_google_mediapipe_left_hands_jpg", sha256 = "4b5134daa4cb60465535239535f9f74c2842aba3aa5fd30bf04ef5678f93d87f", @@ -448,18 +454,42 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_json", + sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"], + ) + + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json", + sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite", sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_without_metadata_tflite", + sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant_without_metadata.tflite?generation=1665988405130772"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite", sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339", urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_without_metadata_tflite", + sha256 = "9f3bc29e38e90842a852bfed957dbf5e36f2d97a91dd17736b1e5c0aca8d3303", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_without_metadata.tflite?generation=1665988408360823"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite", sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78", From de93d06f8795c46e380f2adca15873b552597a93 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 24 Oct 2022 00:43:45 -0700 Subject: [PATCH 35/55] Implement build rules and targets to create MediaPipe Tasks AARs PiperOrigin-RevId: 483297891 --- .../com/google/mediapipe/mediapipe_aar.bzl | 31 +-- mediapipe/tasks/java/BUILD | 16 +- .../java/com/google/mediapipe/tasks/BUILD | 15 ++ .../com/google/mediapipe/tasks/core/BUILD | 12 + .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 211 
++++++++++++++++++ .../com/google/mediapipe/tasks/vision/BUILD | 8 + 6 files changed, 279 insertions(+), 14 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 9b01e2f0b..645e8b722 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -1,4 +1,4 @@ -# Copyright 2019-2020 The MediaPipe Authors. +# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,9 +209,9 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []): def mediapipe_build_aar_with_jni(name, android_library): """Builds MediaPipe AAR with jni. - Args: - name: The bazel target name. - android_library: the android library that contains jni. + Args: + name: The bazel target name. + android_library: the android library that contains jni. """ # Generates dummy AndroidManifest.xml for dummy apk usage @@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""): src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", )) - proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:landmark_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", - )) - proto_src_list.append(mediapipe_java_proto_src_extractor( target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", )) proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:location_data_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + target = "//mediapipe/framework/formats:classification_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", )) proto_src_list.append(mediapipe_java_proto_src_extractor( @@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""): )) proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:classification_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", + target = "//mediapipe/framework/formats:landmark_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:location_data_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:rect_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/RectProto.java", )) return proto_src_list diff --git a/mediapipe/tasks/java/BUILD b/mediapipe/tasks/java/BUILD index 024510737..7e6283261 100644 --- a/mediapipe/tasks/java/BUILD +++ b/mediapipe/tasks/java/BUILD @@ -1 +1,15 @@ -# dummy file for tap test to find the pattern +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD new file mode 100644 index 000000000..7e6283261 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD @@ -0,0 +1,15 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index b4ebfe8cc..cb9d67424 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -36,3 +36,15 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") + +mediapipe_tasks_core_aar( + name = "tasks_core", + srcs = glob(["*.java"]) + [ + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", + "//mediapipe/java/com/google/mediapipe/framework/image:java_src", + ], + manifest = "AndroidManifest.xml", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl new file mode 100644 index 000000000..e0b9c79ed --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -0,0 +1,211 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
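+# For orientation, a usage sketch (not part of the patch itself): the macros
+# defined below are meant to be loaded from BUILD files, mirroring the vision
+# BUILD change at the end of this patch.
+#
+#   load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl",
+#        "mediapipe_tasks_vision_aar")
+#
+#   mediapipe_tasks_vision_aar(
+#       name = "tasks_vision",
+#       srcs = glob(["**/*.java"]),
+#       native_library = ":libmediapipe_tasks_vision_jni_lib",
+#   )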
+
+"""Building MediaPipe Tasks AARs."""
+
+load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build_aar_with_jni", "mediapipe_java_proto_src_extractor", "mediapipe_java_proto_srcs")
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
+_CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
+    "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
+    "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
+    "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+    "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
+    "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
+    "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
+    "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+    "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
+]
+
+_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
+    "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
+]
+
+def mediapipe_tasks_core_aar(name, srcs, manifest):
+    """Builds MediaPipe Tasks core AAR.
+
+    Args:
+      name: The bazel target name.
+      srcs: MediaPipe Tasks' core layer source files.
+      manifest: The Android manifest. 
+    """
+
+    mediapipe_tasks_java_proto_srcs = []
+    for target in _CORE_TASKS_JAVA_PROTO_LITE_TARGETS:
+        mediapipe_tasks_java_proto_srcs.append(
+            _mediapipe_tasks_java_proto_src_extractor(target = target),
+        )
+
+    for target in _VISION_TASKS_JAVA_PROTO_LITE_TARGETS:
+        mediapipe_tasks_java_proto_srcs.append(
+            _mediapipe_tasks_java_proto_src_extractor(target = target),
+        )
+
+    mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
+        src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
+    ))
+
+    mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
+        src_out = "com/google/mediapipe/calculator/proto/InferenceCalculatorProto.java",
+    ))
+
+    android_library(
+        name = name,
+        srcs = srcs + [
+            "//mediapipe/java/com/google/mediapipe/framework:java_src",
+        ] + mediapipe_java_proto_srcs() +
+            mediapipe_tasks_java_proto_srcs,
+        javacopts = [
+            "-Xep:AndroidJdkLibsChecker:OFF",
+        ],
+        manifest = manifest,
+        deps = [
+            "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
+            "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
+            "//mediapipe/framework:calculator_java_proto_lite",
+            "//mediapipe/framework:calculator_profile_java_proto_lite",
+            "//mediapipe/framework:calculator_options_java_proto_lite",
+            "//mediapipe/framework:mediapipe_options_java_proto_lite",
+            "//mediapipe/framework:packet_factory_java_proto_lite",
+            "//mediapipe/framework:packet_generator_java_proto_lite",
+            "//mediapipe/framework:status_handler_java_proto_lite",
+            "//mediapipe/framework:stream_handler_java_proto_lite",
+            "//mediapipe/framework/formats:classification_java_proto_lite",
+            "//mediapipe/framework/formats:detection_java_proto_lite",
+            "//mediapipe/framework/formats:landmark_java_proto_lite",
+            "//mediapipe/framework/formats:location_data_java_proto_lite",
+            "//mediapipe/framework/formats:rect_java_proto_lite",
+            "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+            "//mediapipe/java/com/google/mediapipe/framework/image",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
+            "//third_party:androidx_annotation",
+            "//third_party:autovalue",
+            "@com_google_protobuf//:protobuf_javalite",
+            "@maven//:com_google_guava_guava",
+            "@maven//:com_google_flogger_flogger",
+            "@maven//:com_google_flogger_flogger_system_backend",
+            "@maven//:com_google_code_findbugs_jsr305",
+        ] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
+    )
+
+def mediapipe_tasks_vision_aar(name, srcs, native_library):
+    """Builds MediaPipe Tasks vision AAR.
+
+    Args:
+      name: The bazel target name.
+      srcs: MediaPipe Vision Tasks' source files.
+      native_library: The native library that contains vision tasks' graph and calculators. 
+    """
+
+    native.genrule(
+        name = name + "tasks_manifest_generator",
+        outs = ["AndroidManifest.xml"],
+        cmd = """
+cat > $(OUTS) <<EOF
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="com.google.mediapipe.tasks.vision">
+  <uses-sdk android:minSdkVersion="24" android:targetSdkVersion="30" />
+</manifest>
+EOF
+""",
+    )
+
+    _mediapipe_tasks_aar(
+        name = name,
+        srcs = srcs,
+        manifest = "AndroidManifest.xml",
+        java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
+        native_library = native_library,
+    )
+
+def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
+    """Builds MediaPipe Tasks AAR."""
+
+    # When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
+    # the OpenCV so libraries will be excluded from the AAR package to
+    # save the package size.
+    native.config_setting(
+        name = "exclude_opencv_so_lib",
+        define_values = {
+            "EXCLUDE_OPENCV_SO_LIB": "1",
+        },
+        visibility = ["//visibility:public"],
+    )
+
+    native.cc_library(
+        name = name + "_jni_opencv_cc_lib",
+        srcs = select({
+            "//mediapipe:android_arm64": ["@android_opencv//:libopencv_java3_so_arm64-v8a"],
+            "//mediapipe:android_armeabi": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
+            "//mediapipe:android_arm": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
+            "//mediapipe:android_x86": ["@android_opencv//:libopencv_java3_so_x86"],
+            "//mediapipe:android_x86_64": ["@android_opencv//:libopencv_java3_so_x86_64"],
+            "//conditions:default": [],
+        }),
+        alwayslink = 1,
+    )
+
+    android_library(
+        name = name + "_android_lib",
+        srcs = srcs,
+        manifest = manifest,
+        proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
+        deps = java_proto_lite_targets + [native_library] + [
+            "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+            "//mediapipe/java/com/google/mediapipe/framework/image",
+            "//mediapipe/framework:calculator_options_java_proto_lite",
+            "//mediapipe/framework:calculator_java_proto_lite",
+            "//mediapipe/framework/formats:classification_java_proto_lite",
+            "//mediapipe/framework/formats:detection_java_proto_lite",
+            "//mediapipe/framework/formats:landmark_java_proto_lite",
+            "//mediapipe/framework/formats:location_data_java_proto_lite",
+            "//mediapipe/framework/formats:rect_java_proto_lite",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
+            "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+            "//third_party:autovalue",
+            "@maven//:com_google_guava_guava",
+        ] + select({
+            "//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
+            "//mediapipe/framework/port:disable_opencv": [],
+            "exclude_opencv_so_lib": [],
+        }),
+    )
+
+    mediapipe_build_aar_with_jni(name, name + "_android_lib")
+
+def _mediapipe_tasks_java_proto_src_extractor(target):
+    proto_path = "com/google/" + target.split(":")[0].replace("cc/", "").replace("//", "").replace("_", "") + "/"
+    proto_name = target.split(":")[-1].replace("_java_proto_lite", "").replace("_", " ").title().replace(" ", "") + "Proto.java"
+    return mediapipe_java_proto_src_extractor(
+        target = target,
+        src_out = proto_path + proto_name,
+    )
diff --git
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 2bdcc2522..5ea465d47 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -141,3 +141,11 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") + +mediapipe_tasks_vision_aar( + name = "tasks_vision", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_vision_jni_lib", +) From af051dcb628782ec33188b425b64c5b516c33d59 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 01:37:49 -0700 Subject: [PATCH 36/55] internal change PiperOrigin-RevId: 483308781 --- mediapipe/framework/calculator_base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index f9f0d7a8a..19f37f9de 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -185,7 +185,7 @@ class CalculatorBaseFactory { // Functions for checking that the calculator has the required GetContract. template constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { - typedef absl::Status (*GetContractType)(CalculatorContract * cc); + typedef absl::Status (*GetContractType)(CalculatorContract* cc); return std::is_same::value; } template From 0fd69e8d838d71e364e019f3eb29eb4389dbec7b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 09:09:42 -0700 Subject: [PATCH 37/55] Open-source some tokenizer unit tests. PiperOrigin-RevId: 483399326 --- mediapipe/tasks/cc/text/tokenizers/BUILD | 40 ++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index e76d943c5..5ce08b2d7 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -73,6 +73,19 @@ cc_library( ], ) +cc_test( + name = "sentencepiece_tokenizer_test", + srcs = ["sentencepiece_tokenizer_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + ], + deps = [ + ":sentencepiece_tokenizer", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/core:utils", + ], +) + cc_library( name = "tokenizer_utils", srcs = ["tokenizer_utils.cc"], @@ -95,6 +108,33 @@ cc_library( ], ) +cc_test( + name = "tokenizer_utils_test", + srcs = ["tokenizer_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + "//mediapipe/tasks/testdata/text:mobile_bert_model", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + linkopts = ["-ldl"], + deps = [ + ":bert_tokenizer", + ":regex_tokenizer", + ":sentencepiece_tokenizer", + ":tokenizer_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + cc_library( name = "regex_tokenizer", srcs = [ From 2f2baeff6858bb8c5195910e41ca070bd8cda10f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 10:12:41 -0700 Subject: [PATCH 38/55] Add support for 
rotation in ImageEmbedder & ImageSegmenter C++ APIs PiperOrigin-RevId: 483416498 --- .../tasks/cc/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.cc | 34 +++---- .../cc/vision/image_embedder/image_embedder.h | 56 ++++++++---- .../image_embedder/image_embedder_test.cc | 88 +++++++++++++++++-- .../tasks/cc/vision/image_segmenter/BUILD | 3 + .../vision/image_segmenter/image_segmenter.cc | 42 +++++++-- .../vision/image_segmenter/image_segmenter.h | 34 +++++-- .../image_segmenter/image_segmenter_graph.cc | 19 ++-- .../image_segmenter/image_segmenter_test.cc | 60 ++++++++++++- mediapipe/tasks/testdata/vision/BUILD | 4 + third_party/external_files.bzl | 32 ++++--- 11 files changed, 301 insertions(+), 72 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index e619b8d1b..0f63f87e4 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -58,6 +58,7 @@ cc_library( "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index 24fd2862c..1dc316305 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::vision::image_embedder::proto:: ImageEmbedderGraphOptions; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; -} - // Creates a MediaPipe graph config that contains a single node of type // "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is // running in the live stream mode, a "FlowLimiterCalculator" will be added to @@ -148,15 +139,16 @@ absl::StatusOr> ImageEmbedder::Create( } absl::StatusOr ImageEmbedder::Embed( - Image image, std::optional roi) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? 
roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -167,15 +159,16 @@ absl::StatusOr ImageEmbedder::Embed( } absl::StatusOr ImageEmbedder::EmbedForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -188,16 +181,17 @@ absl::StatusOr ImageEmbedder::EmbedForVideo( return output_packets[kEmbeddingResultStreamName].Get(); } -absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageEmbedder::EmbedAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h index 13f4702d1..3a2a1dbee 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -21,11 +21,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs embedding extraction on the provided single image. Extraction - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. 
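// A minimal caller-side sketch of these options (for illustration only; it
// mirrors the values used in the tests added further down in this patch, and
// assumes an `image_embedder` and a decoded `image` already exist):
//
//   using ::mediapipe::tasks::components::containers::Rect;
//   using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
//   // Crop to a region-of-interest, then rotate the crop by -90 degrees.
//   Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
//   ImageProcessingOptions processing_options{roi, /*rotation_degrees=*/-90};
//   absl::StatusOr<EmbeddingResult> result =
//       image_embedder->Embed(image, processing_options);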
// // Only use this method when the ImageEmbedder is created with the image // running mode. @@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. absl::StatusOr Embed( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs embedding extraction on the provided video frame. Extraction - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the video // running mode. @@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr EmbedForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); // Sends live image data to embedder, and the results will be available via - // the "result_callback" provided in the ImageEmbedderOptions. Embedding - // extraction is performed on the region of interested specified by the `roi` - // argument if provided, or on the entire image otherwise. + // the "result_callback" provided in the ImageEmbedderOptions. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the live // stream running mode. @@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status EmbedAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageEmbedder when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index db1019b33..386b6c8eb 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" @@ -42,7 +41,9 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN( Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); - // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". - NormalizedRect roi; - roi.set_x_center(200.0 / 480); - roi.set_y_center(0.5); - roi.set_width(400.0 / 480); - roi.set_height(1.0f); + // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". + Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, - image_embedder->Embed(image, roi)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& image_result, + image_embedder->Embed(image, image_processing_options)); MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, image_embedder->Embed(crop)); @@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a rotated version of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. + CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.572265; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + // Region-of-interest corresponding to burger_crop.jpg. + Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. + CheckMobileNetV3Result(crop_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.62838; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6bdbf41da..81cd43e34 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -24,10 +24,12 @@ cc_library( ":image_segmenter_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", @@ -48,6 +50,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 84ceea88a..209ee0df3 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -17,8 +17,10 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" @@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig( auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> graph.Out(kGroupedSegmentationTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, + {kImageTag, kNormRectTag}, + kGroupedSegmentationTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); return graph.GetConfig(); } @@ -139,47 +146,68 @@ absl::StatusOr> ImageSegmenter::Create( } absl::StatusOr> ImageSegmenter::Segment( - mediapipe::Image image) { + mediapipe::Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, - ProcessImageData({{kImageInStreamName, - mediapipe::MakePacket(std::move(image))}})); + ProcessImageData( + {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); return output_packets[kSegmentationStreamName].Get>(); } absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms) { + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); return output_packets[kSegmentationStreamName].Get>(); } -absl::Status 
ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) { +absl::Status ImageSegmenter::SegmentAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index e2734c4e4..54269ec0e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -25,6 +25,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "tensorflow/lite/kernels/register.h" @@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // running mode. // // The image can be of any size with format RGB or RGBA. - // TODO: Describes how the input image will be preprocessed - // after the yuv support is implemented. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. - absl::StatusOr> Segment(mediapipe::Image image); + absl::StatusOr> Segment( + mediapipe::Image image, + std::optional image_processing_options = + std::nullopt); // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video @@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. 
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
-      mediapipe::Image image, int64 timestamp_ms);
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<core::ImageProcessingOptions> image_processing_options =
+          std::nullopt);

   // Sends live image data to perform image segmentation, and the results will
   // be available via the "result_callback" provided in the
@@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // sent to the image segmenter. The input timestamps must be monotonically
   // increasing.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing segmentation, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // The "result_callback" provides
   // - A vector of segmented image masks.
   //   If the output_type is CATEGORY_MASK, the returned vector of images is
@@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   //     no longer be valid when the callback returns. To access the image data
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
-  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
+  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
+                            std::optional<core::ImageProcessingOptions>
+                                image_processing_options = std::nullopt);

   // Shuts down the ImageSegmenter when all works are done.
   absl::Status Close() { return runner_->Close(); }
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
index 1678dd083..629b940aa 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
 constexpr char kSegmentationTag[] = "SEGMENTATION";
 constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";

@@ -159,6 +161,10 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
 // Inputs:
 //   IMAGE - Image
 //     Image to perform segmentation on.
+//   NORM_RECT - NormalizedRect @Optional
+//     Describes image rotation and region of image to perform segmentation
+//     on.
+//     @Optional: rect covering the whole image is used if not specified. 
// // Outputs: // SEGMENTATION - mediapipe::Image @Multiple @@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto output_streams, - BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN( + auto output_streams, + BuildSegmentationTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); @@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index ab23a725c..07235563b 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -29,8 +29,10 @@ limitations under the License. #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -44,6 +46,8 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); EXPECT_EQ(confidence_masks.size(), 21); @@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "cat_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + 
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  ImageProcessingOptions image_processing_options;
+  image_processing_options.rotation_degrees = -90;
+  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
+                          segmenter->Segment(image, image_processing_options));
+  EXPECT_EQ(confidence_masks.size(), 21);
+
+  cv::Mat expected_mask =
+      cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
+                 cv::IMREAD_GRAYSCALE);
+  cv::Mat expected_mask_float;
+  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
+
+  // Cat category index 8.
+  cv::Mat cat_mask = mediapipe::formats::MatView(
+      confidence_masks[8].GetImageFrameSharedPtr().get());
+  EXPECT_THAT(cat_mask,
+              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
+}
+
+TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image,
+      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
+
+  auto results = segmenter->Segment(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("This task doesn't support region-of-interest"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+}
+
 TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
   Image image =
       GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index ebb8f05a6..c45cc6e69 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -28,6 +28,8 @@ mediapipe_files(srcs = [
     "burger_rotated.jpg",
     "cat.jpg",
     "cat_mask.jpg",
+    "cat_rotated.jpg",
+    "cat_rotated_mask.jpg",
     "cats_and_dogs.jpg",
     "cats_and_dogs_no_resizing.jpg",
     "cats_and_dogs_rotated.jpg",
@@ -84,6 +86,8 @@ filegroup(
     "burger_rotated.jpg",
     "cat.jpg",
     "cat_mask.jpg",
+    "cat_rotated.jpg",
+    "cat_rotated_mask.jpg",
     "cats_and_dogs.jpg",
     "cats_and_dogs_no_resizing.jpg",
     "cats_and_dogs_rotated.jpg",
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 84b354c99..d460387aa 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -76,6 +76,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
     )

+    http_file(
+        name = "com_google_mediapipe_cat_rotated_jpg",
+        sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_cat_rotated_mask_jpg",
+        sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861",
+        urls =
["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"], + ) + http_file( name = "com_google_mediapipe_cats_and_dogs_jpg", sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c", @@ -162,8 +174,8 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt", - sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"], + sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687", + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666304169636598"], ) http_file( @@ -174,8 +186,8 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt", - sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"], + sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb", + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666304171758037"], ) http_file( @@ -258,8 +270,8 @@ def external_files(): http_file( name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt", - sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370", - urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"], + sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c", + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666304174234283"], ) http_file( @@ -606,8 +618,8 @@ def external_files(): http_file( name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt", - sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de", - urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"], + sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6", + urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666304178388806"], ) http_file( @@ -798,8 +810,8 @@ def external_files(): http_file( name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt", - sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102", - urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"], + sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257", + urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666304181397432"], ) http_file( From 2cf9523468e280d736a92a1b681f2c96bb61c8f3 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 24 Oct 2022 10:53:35 -0700 Subject: [PATCH 39/55] Fix the java_package name. 
PiperOrigin-RevId: 483428848 --- mediapipe/tasks/cc/components/containers/proto/category.proto | 2 +- .../cc/components/containers/proto/classifications.proto | 2 +- .../tasks/cc/components/containers/proto/embeddings.proto | 3 +++ .../tasks/text/textclassifier/TextClassificationResult.java | 4 ++-- .../mediapipe/tasks/text/textclassifier/TextClassifier.java | 2 +- .../vision/imageclassifier/ImageClassificationResult.java | 4 ++-- .../tasks/vision/imageclassifier/ImageClassifier.java | 2 +- 7 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto index a154e5f4e..2ba760e99 100644 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -17,7 +17,7 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; -option java_package = "com.google.mediapipe.tasks.components.container.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "CategoryProto"; // A single classification result. diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index 0f5086b95..712607fa6 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto; import "mediapipe/tasks/cc/components/containers/proto/category.proto"; -option java_package = "com.google.mediapipe.tasks.components.container.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; // List of predicted categories with an optional timestamp. diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto index d57b08b53..39811e6c0 100644 --- a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto +++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "EmbeddingsProto"; + // Defines a dense floating-point embedding. 
message FloatEmbedding { repeated float values = 1 [packed = true]; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java index dd9b9a1b3..c1e2446cd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java @@ -15,11 +15,11 @@ package com.google.mediapipe.tasks.text.textclassifier; import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.container.proto.CategoryProto; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.ClassificationEntry; import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 76117d2e4..07a4fa48f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java index 09f854caa..d82a47b86 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java @@ -15,11 +15,11 @@ package com.google.mediapipe.tasks.vision.imageclassifier; import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.container.proto.CategoryProto; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.ClassificationEntry; import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index e7d9e4ea1..75e2de13a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -26,7 +26,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; From 3d88b1797a357e80ffd3a6e85a227d61795c8935 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 11:25:18 -0700 Subject: [PATCH 40/55] Switch CHECK to a status in resource handling code. Expand error message. PiperOrigin-RevId: 483438131 --- mediapipe/util/resource_util_android.cc | 3 ++- mediapipe/util/resource_util_apple.cc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index b18354d5f..1e970f212 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -82,7 +82,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) { // If that fails, assume it was a relative path, and try just the base name. { const size_t last_slash_idx = path.find_last_of("\\/"); - CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + RET_CHECK(last_slash_idx != std::string::npos) + << path << " doesn't have a slash in it"; // Make sure it's a path. auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); if (status_or_path.ok()) {
diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index c812dcb57..f64718348 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -71,7 +71,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) { // If that fails, assume it was a relative path, and try just the base name. { const size_t last_slash_idx = path.find_last_of("\\/"); - CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + RET_CHECK(last_slash_idx != std::string::npos) + << path << " doesn't have a slash in it"; // Make sure it's a path.
auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); if (status_or_path.ok()) { From b8502decff8f798b34d3b981b7dc8c3676cfc6f7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 24 Oct 2022 12:13:11 -0700 Subject: [PATCH 41/55] Update model file upload script to verify full URL PiperOrigin-RevId: 483451219 --- third_party/external_files.bzl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index d460387aa..4b7309eef 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -175,7 +175,7 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt", sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666304169636598"], + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666629474155924"], ) http_file( @@ -187,7 +187,7 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt", sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666304171758037"], + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"], ) http_file( @@ -271,7 +271,7 @@ def external_files(): http_file( name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt", sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c", - urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666304174234283"], + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666629478777955"], ) http_file( @@ -468,14 +468,14 @@ def external_files(): http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_json", - sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b", - urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"], + sha256 = "94613ea9539a20a3352604004be6d4d64d4d76250bc9042fcd8685c9a8498517", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1666633416316646"], ) http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json", - sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6", - urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"], + sha256 = "3703eadcf838b65bbc2b2aa11dbb1f1bc654c7a09a7aba5ca75a26096484a8ac", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1666633418665507"], ) http_file( @@ -619,7 +619,7 @@ def external_files(): http_file( name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt", sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6", - urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666304178388806"], + urls = 
["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) http_file( @@ -811,7 +811,7 @@ def external_files(): http_file( name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt", sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257", - urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666304181397432"], + urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666629489421733"], ) http_file( From 94cd1348096bae81f9cf6bdeb5ed5b5de96b66b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 14:59:09 -0700 Subject: [PATCH 42/55] Add support for image rotation in Java vision tasks. PiperOrigin-RevId: 483493729 --- .../android/objectdetector/src/main/BUILD | 1 + .../examples/objectdetector/MainActivity.java | 38 ++-- .../com/google/mediapipe/tasks/vision/BUILD | 1 + .../tasks/vision/core/BaseVisionTaskApi.java | 181 ++++-------------- .../vision/core/ImageProcessingOptions.java | 92 +++++++++ .../gesturerecognizer/GestureRecognizer.java | 128 +++++++++++-- .../imageclassifier/ImageClassifier.java | 121 ++++++------ .../vision/objectdetector/ObjectDetector.java | 127 ++++++++++-- .../tasks/vision/core/AndroidManifest.xml | 24 +++ .../google/mediapipe/tasks/vision/core/BUILD | 19 ++ .../core/ImageProcessingOptionsTest.java | 70 +++++++ .../GestureRecognizerTest.java | 79 +++++++- .../imageclassifier/ImageClassifierTest.java | 79 +++++++- .../objectdetector/ObjectDetectorTest.java | 85 ++++++-- 14 files changed, 762 insertions(+), 283 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD index acbdbd6eb..89c1edcb3 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -31,6 +31,7 @@ android_binary( multidex = "native", resource_files = ["//mediapipe/tasks/examples/android:resource_files"], deps = [ + "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java index 11c8c1837..18c010a00 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector; import android.content.Intent; import android.graphics.Bitmap; -import 
android.graphics.Matrix; import android.media.MediaMetadataRetriever; import android.os.Bundle; import android.provider.MediaStore; @@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.exifinterface.media.ExifInterface; // ContentResolver dependency +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; @@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity { if (resultIntent != null) { if (result.getResultCode() == RESULT_OK) { Bitmap bitmap = null; + int rotation = 0; try { bitmap = downscaleBitmap( @@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity { try { InputStream imageData = this.getContentResolver().openInputStream(resultIntent.getData()); - bitmap = rotateBitmap(bitmap, imageData); - } catch (IOException e) { + rotation = getImageRotation(imageData); + } catch (IOException | MediaPipeException e) { Log.e(TAG, "Bitmap rotation error:" + e); } if (bitmap != null) { MPImage image = new BitmapImageBuilder(bitmap).build(); - ObjectDetectionResult detectionResult = objectDetector.detect(image); + ObjectDetectionResult detectionResult = + objectDetector.detect( + image, + ImageProcessingOptions.builder().setRotationDegrees(rotation).build()); imageView.setData(image, detectionResult); runOnUiThread(() -> imageView.update()); } @@ -210,28 +215,25 @@ public class MainActivity extends AppCompatActivity { return Bitmap.createScaledBitmap(originalBitmap, width, height, false); } - private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException { int orientation = new ExifInterface(imageData) .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); - if (orientation == ExifInterface.ORIENTATION_NORMAL) { - return inputBitmap; - } - Matrix matrix = new Matrix(); switch (orientation) { + case ExifInterface.ORIENTATION_NORMAL: + return 0; case ExifInterface.ORIENTATION_ROTATE_90: - matrix.postRotate(90); - break; + return 90; case ExifInterface.ORIENTATION_ROTATE_180: - matrix.postRotate(180); - break; + return 180; case ExifInterface.ORIENTATION_ROTATE_270: - matrix.postRotate(270); - break; + return 270; default: - matrix.postRotate(0); + // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip + // is supported. 
+ throw new MediaPipeException( + MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), + "Flipped images are not supported yet."); } - return Bitmap.createBitmap( - inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 5ea465d47..ed65fbcac 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -28,6 +28,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 49dab408c..0774b69a2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** The base class of MediaPipe vision tasks. */ public class BaseVisionTaskApi implements AutoCloseable { @@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable { private final TaskRunner runner; private final RunningMode runningMode; private final String imageStreamName; - private final Optional normRectStreamName; + private final String normRectStreamName; static { System.loadLibrary("mediapipe_tasks_vision_jni"); @@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } /** - * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. + * Constructor to initialize a {@link BaseVisionTaskApi}. * * @param runner a {@link TaskRunner}. * @param runningMode a mediapipe vision task {@link RunningMode}. * @param imageStreamName the name of the input image stream. - */ - public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { - this.runner = runner; - this.runningMode = runningMode; - this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.empty(); - } - - /** - * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as - * input. - * - * @param runner a {@link TaskRunner}. - * @param runningMode a mediapipe vision task {@link RunningMode}. - * @param imageStreamName the name of the input image stream. - * @param normRectStreamName the name of the input normalized rect image stream. + * @param normRectStreamName the name of the input normalized rect image stream used to provide + * (mandatory) rotation and (optional) region-of-interest. */ public BaseVisionTaskApi( TaskRunner runner, @@ -70,7 +55,7 @@ public class BaseVisionTaskApi implements AutoCloseable { this.runner = runner; this.runningMode = runningMode; this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.of(normRectStreamName); + this.normRectStreamName = normRectStreamName; } /** @@ -78,53 +63,23 @@ public class BaseVisionTaskApi implements AutoCloseable { * failure status or a successful result is returned. 
* * @param image a MediaPipe {@link MPImage} object for processing. - * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect - * input. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if the task is not in the image mode. */ - protected TaskResult processImageData(MPImage image) { + protected TaskResult processImageData( + MPImage image, ImageProcessingOptions imageProcessingOptions) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets); - } - - /** - * A synchronous method to process single image inputs. The call blocks the current thread until a - * failure status or a successful result is returned. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized - * rect. - */ - protected TaskResult processImageData(MPImage image, RectF roi) { - if (runningMode != RunningMode.IMAGE) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the image mode. Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets); } @@ -133,55 +88,24 @@ public class BaseVisionTaskApi implements AutoCloseable { * until a failure status or a successful result is returned. * * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the video mode. */ - protected TaskResult processVideoData(MPImage image, long timestampMs) { + protected TaskResult processVideoData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the video mode. 
Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * A synchronous method to process continuous video frames. The call blocks the current thread - * until a failure status or a successful result is returned. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.VIDEO) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the video mode. Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -190,55 +114,24 @@ public class BaseVisionTaskApi implements AutoCloseable { * available in the user-defined result listener. * * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the stream mode. */ - protected void sendLiveStreamData(MPImage image, long timestampMs) { + protected void sendLiveStreamData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the live stream mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * An asynchronous method to send live stream data to the {@link TaskRunner}. 
The results will be - * available in the user-defined result listener. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.LIVE_STREAM) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the live stream mode. Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable { runner.close(); } - /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ - private static NormalizedRect convertToNormalizedRect(RectF rect) { + /** + * Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf + * message. + */ + private static NormalizedRect convertToNormalizedRect( + ImageProcessingOptions imageProcessingOptions) { + RectF regionOfInterest = + imageProcessingOptions.regionOfInterest().isPresent() + ? imageProcessingOptions.regionOfInterest().get() + : new RectF(0, 0, 1, 1); return NormalizedRect.newBuilder() - .setXCenter(rect.centerX()) - .setYCenter(rect.centerY()) - .setWidth(rect.width()) - .setHeight(rect.height()) + .setXCenter(regionOfInterest.centerX()) + .setYCenter(regionOfInterest.centerY()) + .setWidth(regionOfInterest.width()) + .setHeight(regionOfInterest.height()) + // Convert to radians anti-clockwise. + .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f) .build(); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java new file mode 100644 index 000000000..a34a9787d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java @@ -0,0 +1,92 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.core; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import java.util.Optional; + +// TODO: add support for image flipping. +/** Options for image processing. */ +@AutoValue +public abstract class ImageProcessingOptions { + + /** + * Builder for {@link ImageProcessingOptions}. + * + *
<p>If both region-of-interest and rotation are specified, the crop around the + * region-of-interest is extracted first, then the specified rotation is applied to the crop. + */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the optional region-of-interest to crop from the image. If not specified, the full image + * is used. + * + *
<p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be + * < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()} + * is called. + */ + public abstract Builder setRegionOfInterest(RectF value); + + /** + * Sets the rotation to apply to the image (or cropped region-of-interest), in degrees + * clockwise. Defaults to 0. + * + *
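<p>A minimal sketch (not part of this patch; the values are arbitrary) of how the two setters
+ * combine, given the crop-then-rotate order documented on this builder:
+ *
+ * <pre>{@code
+ * ImageProcessingOptions options =
+ *     ImageProcessingOptions.builder()
+ *         .setRegionOfInterest(new RectF(0.25f, 0.25f, 0.75f, 0.75f)) // centered crop
+ *         .setRotationDegrees(90) // quarter turn clockwise, applied to the crop
+ *         .build(); // validates both values, see below
+ * }</pre>
+ *
+ *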
<p>The rotation must be a multiple (positive or negative) of 90°, otherwise an + * IllegalArgumentException will be thrown when {@link #build()} + * is called. + */ + public abstract Builder setRotationDegrees(int value); + + abstract ImageProcessingOptions autoBuild(); + + /** + * Validates and builds the {@link ImageProcessingOptions} instance. + * + * @throws IllegalArgumentException if some of the provided values do not meet their + * requirements. + */ + public final ImageProcessingOptions build() { + ImageProcessingOptions options = autoBuild(); + if (options.regionOfInterest().isPresent()) { + RectF roi = options.regionOfInterest().get(); + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new IllegalArgumentException( + String.format( + "Expected left < right and top < bottom, found: %s.", roi.toShortString())); + } + if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) { + throw new IllegalArgumentException( + String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString())); + } + } + if (options.rotationDegrees() % 90 != 0) { + throw new IllegalArgumentException( + String.format( + "Expected rotation to be a multiple of 90°, found: %d.", + options.rotationDegrees())); + } + return options; + } + } + + public abstract Optional<RectF> regionOfInterest(); + + public abstract int rotationDegrees(); + + public static Builder builder() { + return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0); + } +}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 55cf275e9..8e5a30eab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import android.content.Context; -import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; @@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto; @@ -212,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + + /** + * Performs gesture recognition on the provided single image with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *
<p>{@link GestureRecognizer} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
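+ *
+ * <p>A minimal sketch of the image-mode call pattern (the model asset path is hypothetical;
+ * {@code createFromOptions} is the usual MediaPipe task factory, and the builders follow the
+ * APIs shown elsewhere in this patch):
+ *
+ * <pre>{@code
+ * GestureRecognizerOptions options =
+ *     GestureRecognizerOptions.builder()
+ *         .setBaseOptions(
+ *             BaseOptions.builder().setModelAssetPath("gesture_recognizer.task").build())
+ *         .setRunningMode(RunningMode.IMAGE)
+ *         .build();
+ * GestureRecognizer recognizer = GestureRecognizer.createFromOptions(context, options);
+ * GestureRecognitionResult result =
+ *     recognizer.recognize(new BitmapImageBuilder(bitmap).build());
+ * }</pre>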
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognize(MPImage image) { + return recognize(image, ImageProcessingOptions.builder().build()); + } + /** * Performs gesture recognition on the provided single image. Only use this method when the {@link * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc @@ -223,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognize(MPImage inputImage) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF()); + public GestureRecognitionResult recognize( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs gesture recognition on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is + * created with {@link RunningMode.VIDEO}. + * + *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *
<p>{@link GestureRecognizer} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
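+ *
+ * <p>A sketch of a video loop satisfying the timestamp requirement (the frame decoder is
+ * hypothetical; only the monotonically increasing timestamps matter here):
+ *
+ * <pre>{@code
+ * for (long timestampMs = 0; timestampMs < durationMs; timestampMs += frameIntervalMs) {
+ *   MPImage frame = decodeFrameAt(timestampMs); // hypothetical decoder
+ *   GestureRecognitionResult result = recognizer.recognizeForVideo(frame, timestampMs);
+ * }
+ * }</pre>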
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) { + return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -244,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public GestureRecognitionResult recognizeForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform gesture recognition with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the + * {@link GestureRecognition} is created with {@link RunningMode.LIVE_STREAM}. + * + *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the gesture recognizer. The input timestamps must be monotonically increasing. + * + *
<p>{@link GestureRecognizer} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
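+ *
+ * <p>A sketch of live-stream setup; the result listener is supplied via the options as described
+ * above (its exact generic signature is assumed here):
+ *
+ * <pre>{@code
+ * GestureRecognizerOptions options =
+ *     GestureRecognizerOptions.builder()
+ *         .setBaseOptions(baseOptions)
+ *         .setRunningMode(RunningMode.LIVE_STREAM)
+ *         .setResultListener((result, inputImage) -> handleResult(result)) // hypothetical handler
+ *         .build();
+ * GestureRecognizer recognizer = GestureRecognizer.createFromOptions(context, options);
+ * recognizer.recognizeAsync(cameraFrame, frameTimestampMs);
+ * }</pre>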
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void recognizeAsync(MPImage image, long timestampMs) { + recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -268,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void recognizeAsync(MPImage inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); + public void recognizeAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link GestureRecognizer}. */ @@ -445,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi { } } - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest."); + } } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 75e2de13a..3863b6fe0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.imageclassifier; import android.content.Context; -import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; @@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; import java.io.File; @@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + /** + * Performs classification on the provided single image with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied. 
Only use + * this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *
<p>{@link ImageClassifier} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
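+ *
+ * <p>Unlike the detector and gesture recognizer, this task does accept a region-of-interest, so
+ * a rotated crop can be classified in one call via the overload below; a sketch with arbitrary
+ * values:
+ *
+ * <pre>{@code
+ * ImageProcessingOptions options =
+ *     ImageProcessingOptions.builder()
+ *         .setRegionOfInterest(new RectF(0.1f, 0.1f, 0.9f, 0.9f))
+ *         .setRotationDegrees(-90) // negative multiples of 90° are allowed
+ *         .build();
+ * ImageClassificationResult result = classifier.classify(image, options);
+ * }</pre>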
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(MPImage image) { + return classify(image, ImageProcessingOptions.builder().build()); + } + /** * Performs classification on the provided single image. Only use this method when the {@link * ImageClassifier} is created with {@link RunningMode.IMAGE}. @@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage inputImage) { - return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); + public ImageClassificationResult classify( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageClassificationResult) processImageData(image, imageProcessingOptions); } /** - * Performs classification on the provided single image and region-of-interest. Only use this - * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * Performs classification on the provided video frame with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied. Only use this + * method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. * *
<p>{@link ImageClassifier} supports the following color space types: * *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage inputImage, RectF roi) { - return (ImageClassificationResult) processImageData(inputImage, roi); + public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { + return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) { - return (ImageClassificationResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public ImageClassificationResult classifyForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** - * Performs classification on the provided video frame with additional region-of-interest. Only - * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * Sends live image data to perform classification with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied, and the results + * will be available via the {@link ResultListener} provided in the {@link + * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with + * {@link RunningMode.LIVE_STREAM}. * - *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps - * must be monotonically increasing. + *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. * *
<p>{@link ImageClassifier} supports the following color space types: * *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo( - MPImage inputImage, RectF roi, long inputTimestampMs) { - return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); + public void classifyAsync(MPImage image, long timestampMs) { + classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void classifyAsync(MPImage inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); - } - - /** - * Sends live image data and additional region-of-interest to perform classification, and the - * results will be available via the {@link ResultListener} provided in the {@link - * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with - * {@link RunningMode.LIVE_STREAM}. - * - *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is - * sent to the object detector. The input timestamps must be monotonically increasing. - * - *
<p>{@link ImageClassifier} supports the following color space types: - * - * <ul> - *   <li>{@link Bitmap.Config.ARGB_8888} - * </ul>
    - * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). - * @throws MediaPipeException if there is an internal error. - */ - public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) { - sendLiveStreamData(inputImage, roi, inputTimestampMs); + public void classifyAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up and {@link ImageClassifier}. */ @@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi { .build(); } } - - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); - } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 0f2e7b540..3f944eaee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto; import com.google.mediapipe.formats.proto.DetectionProto.Detection; @@ -96,8 +97,10 @@ import java.util.Optional; public final class ObjectDetector extends BaseVisionTaskApi { private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final List INPUT_STREAMS = - Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); private static final int DETECTIONS_OUT_STREAM_INDEX = 0; @@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs object detection on the provided single image with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.IMAGE}. + * + *
<p>{@link ObjectDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
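+ *
+ * <p>In practice the rotation usually comes from EXIF metadata, as in the example activity
+ * updated by this patch; a sketch reusing its {@code getImageRotation} helper:
+ *
+ * <pre>{@code
+ * int rotationDegrees = getImageRotation(imageData); // EXIF-derived, see MainActivity
+ * ObjectDetectionResult result =
+ *     detector.detect(
+ *         image, ImageProcessingOptions.builder().setRotationDegrees(rotationDegrees).build());
+ * }</pre>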
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); } /** @@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect(MPImage inputImage) { - return (ObjectDetectionResult) processImageData(inputImage); + public ObjectDetectionResult detect( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs object detection on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.VIDEO}. + * + *
<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *
<p>{@link ObjectDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
<li>{@link Bitmap.Config.ARGB_8888} * </ul> * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) { - return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); + public ObjectDetectionResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform object detection with default image processing options, i.e. + * without any rotation applied, and the results will be available via the {@link ResultListener} + * provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link + * ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *
<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + * <p>{@link ObjectDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
<li>{@link Bitmap.Config.ARGB_8888} * </ul> * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void detectAsync(MPImage inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, inputTimestampMs); + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link ObjectDetector}. */ @@ -415,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi { .build(); } } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest."); + } + } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml new file mode 100644 index 000000000..aa2df6baf --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java new file mode 100644 index 000000000..078b62af1 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java @@ -0,0 +1,70 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.graphics.RectF; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link ImageProcessingOptions}/ */ +@RunWith(AndroidJUnit4.class) +public final class ImageProcessingOptionsTest { + + @Test + public void succeedsWithValidInputs() throws Exception { + ImageProcessingOptions options = + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f)) + .setRotationDegrees(270) + .build(); + } + + @Test + public void failsWithLeftHigherThanRight() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithBottomHigherThanTop() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithInvalidRotation() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ImageProcessingOptions.builder().setRotationDegrees(1).build()); + assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°"); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index 31e59a259..eca5d35c2 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertThrows; import android.content.res.AssetManager; import android.graphics.BitmapFactory; +import android.graphics.RectF; import androidx.test.core.app.ApplicationProvider; 
import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.truth.Correspondence; @@ -30,6 +31,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions; import java.io.InputStream; @@ -46,11 +48,14 @@ public class GestureRecognizerTest { private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task"; private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; + private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg"; private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb"; private static final String TAG = "Gesture Recognizer Test"; private static final String THUMB_UP_LABEL = "Thumb_Up"; private static final int THUMB_UP_INDEX = 5; + private static final String POINTING_UP_LABEL = "Pointing_Up"; + private static final int POINTING_UP_INDEX = 3; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; private static final int IMAGE_WIDTH = 382; private static final int IMAGE_HEIGHT = 406; @@ -135,6 +140,53 @@ public class GestureRecognizerTest { gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE)); assertThat(actualResult.handednesses()).hasSize(2); } + + @Test + public void recognize_successWithRotation() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize( + getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions); + assertThat(actualResult.gestures()).hasSize(1); + assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX); + assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL); + } + + @Test + public void recognize_failsWithRegionOfInterest() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + gestureRecognizer.recognize( + getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + 
.contains("GestureRecognizer doesn't support region-of-interest"); + } } @RunWith(AndroidJUnit4.class) @@ -195,12 +247,16 @@ public class GestureRecognizerTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -225,7 +281,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -251,7 +309,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -291,7 +351,8 @@ public class GestureRecognizerTest { getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); for (int i = 0; i < 3; i++) { GestureRecognitionResult actualResult = - gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i); + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); } } @@ -317,9 +378,11 @@ public class GestureRecognizerTest { .build(); try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - gestureRecognizer.recognizeAsync(image, 1); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -348,7 +411,7 @@ public class GestureRecognizerTest { try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - gestureRecognizer.recognizeAsync(image, i); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ i); } } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 966e4ff4a..99ebd9777 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ 
b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions; import java.io.InputStream; @@ -47,7 +48,9 @@ public class ImageClassifierTest { private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite"; private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite"; private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg"; + private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg"; @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { @@ -209,13 +212,60 @@ public class ImageClassifierTest { ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); // RectF around the soccer ball. RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageClassificationResult results = - imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi); + imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); assertHasOneHeadAndOneTimestamp(results, 0); assertCategoriesAre( results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); } + + @Test + public void classify_succeedsWithRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageClassificationResult results = + imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.6390683f, 934, "cheeseburger", ""), + Category.create(0.0495407f, 963, "meat loaf", ""), + Category.create(0.0469720f, 925, "guacamole", ""))); + } + + @Test + public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // RectF around the chair. 
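+      // Coordinates are normalized to [0,1].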
+ RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageClassificationResult results = + imageClassifier.classify( + getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); + } } @RunWith(AndroidJUnit4.class) @@ -269,12 +319,16 @@ public class ImageClassifierTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -296,7 +350,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -320,7 +376,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -352,7 +410,8 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassificationResult results = imageClassifier.classifyForVideo(image, i); + ImageClassificationResult results = + imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); assertHasOneHeadAndOneTimestamp(results, i); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -377,9 +436,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -405,7 +466,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, i); + imageClassifier.classifyAsync(image, /*timestampMs=*/ i); } } } 
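The tests above also document the shape of the migrated API: region-of-interest and rotation now travel together in a single ImageProcessingOptions argument rather than as a bare RectF. A minimal caller-side sketch follows, assuming an Android Context named context and an MPImage named image (placeholders, not part of this patch); the model file, ROI, and rotation values are borrowed from the tests above.

  ImageClassifierOptions options =
      ImageClassifierOptions.builder()
          .setBaseOptions(
              BaseOptions.builder().setModelAssetPath("mobilenet_v2_1.0_224.tflite").build())
          .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
          .build();
  ImageClassifier imageClassifier = ImageClassifier.createFromOptions(context, options);

  // ROI coordinates are normalized to [0,1]; rotation must be a multiple of 90 degrees.
  ImageProcessingOptions imageProcessingOptions =
      ImageProcessingOptions.builder()
          .setRegionOfInterest(new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f))
          .setRotationDegrees(-90)
          .build();
  ImageClassificationResult results = imageClassifier.classify(image, imageProcessingOptions);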
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index 91ffa9273..2878c380d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; import java.io.InputStream; @@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses; public class ObjectDetectorTest { private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; + private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg"; private static final int IMAGE_WIDTH = 1200; private static final int IMAGE_HEIGHT = 600; private static final float CAT_SCORE = 0.69f; - private static final RectF catBoundingBox = new RectF(611, 164, 986, 596); + private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596); // TODO: Figure out why android_x86 and android_arm tests have slightly different // scores (0.6875 vs 0.69921875). private static final float SCORE_DIFF_TOLERANCE = 0.01f; @@ -67,7 +69,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -104,7 +106,7 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // The score threshold should block all other other objects, except cat. 
- assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -175,7 +177,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -228,6 +230,46 @@ public class ObjectDetectorTest { .contains("`category_allowlist` and `category_denylist` are mutually exclusive options."); } + @Test + public void detect_succeedsWithRotation() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMaxResults(1) + .setCategoryAllowlist(Arrays.asList("cat")) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ObjectDetectionResult results = + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); + + assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("ObjectDetector doesn't support region-of-interest"); + } + // TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation, // detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions, // detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero. 
@@ -282,12 +324,16 @@ public class ObjectDetectorTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -309,7 +355,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -333,7 +381,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -348,7 +398,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -363,8 +413,9 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { ObjectDetectionResult results = - objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } } @@ -377,16 +428,18 @@ public class ObjectDetectorTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) .build(); try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - objectDetector.detectAsync(image, 1); + objectDetector.detectAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -402,7 +455,7 @@ public class ObjectDetectorTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, 
CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) @@ -410,7 +463,7 @@ public class ObjectDetectorTest { try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - objectDetector.detectAsync(image, i); + objectDetector.detectAsync(image, /*timestampsMs=*/ i); } } } From 6b0a7fb281657960ac4078ae5c72617eb00fe156 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 15:44:08 -0700 Subject: [PATCH 43/55] Reverting back to special handling for Egl Thread Exit on Android PiperOrigin-RevId: 483505151 --- mediapipe/gpu/gl_context_egl.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 13710a688..78b196b08 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static void EglThreadExitCallback(void* key_value) { +#if defined(__ANDROID__) + eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE, + EGL_NO_CONTEXT); +#else // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // parameter for eglMakeCurrent. This behavior is not portable to all EGL // implementations, and should be considered as an undocumented vendor // extension. // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml + // + // NOTE: crashes on some Android devices (occurs with libGLES_meow.so). eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); +#endif eglReleaseThread(); } From d240c009e2afa2849836498d129eca6e0d78f637 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 25 Oct 2022 08:20:42 -0700 Subject: [PATCH 44/55] Remove unnecessary location_data_proto dependency on rect_proto. PiperOrigin-RevId: 483679555 --- mediapipe/framework/formats/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index b967b27fb..c3241d911 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -293,7 +293,6 @@ mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework/formats:location_data_proto"], ) mediapipe_register_type( From 21abfc9125cb69f4c029420732cd4d4958904ca7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 25 Oct 2022 12:12:18 -0700 Subject: [PATCH 45/55] Update gpu origin. 
PiperOrigin-RevId: 483742652 --- mediapipe/tasks/cc/components/BUILD | 1 + mediapipe/tasks/cc/components/image_preprocessing.cc | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index e4905546a..f563fbf64 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -46,6 +46,7 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index 046a97e4d..f3f3b6863 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -128,6 +129,9 @@ absl::Status ConfigureImageToTensorCalculator( options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / std); } + // TODO: need to.support different GPU origin on differnt + // platforms or applications. + options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); return absl::OkStatus(); } From 36bd9abb8f3f7c0dd4e5d54ebf573c2484cae666 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 25 Oct 2022 12:51:22 -0700 Subject: [PATCH 46/55] Internal change PiperOrigin-RevId: 483751427 --- .../tensors_to_classification_calculator.cc | 7 ----- ...tensors_to_classification_calculator.proto | 5 ---- ...nsors_to_classification_calculator_test.cc | 30 ------------------- .../framework/formats/classification.proto | 4 --- 4 files changed, 46 deletions(-) diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index 76d2869e8..5bfc00ed7 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { } absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { - const auto& options = cc->Options(); const auto& input_tensors = *kInTensors(cc); RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); @@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { auto raw_scores = view.buffer(); auto classification_list = absl::make_unique(); - if (options.has_tensor_index()) { - classification_list->set_tensor_index(options.tensor_index()); - } - if (options.has_tensor_name()) { - classification_list->set_tensor_name(options.tensor_name()); - } if (is_binary_classification_) { Classification* class_first = classification_list->add_classification(); Classification* class_second = classification_list->add_classification(); diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto 
b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto index f0f7727ba..32bc4b63a 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto @@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions { // that are not in the `allow_classes` field will be completely ignored. // `ignore_classes` and `allow_classes` are mutually exclusive. repeated int32 allow_classes = 8 [packed = true]; - - // The optional index of the tensor these classifications originate from. - optional int32 tensor_index = 10; - // The optional name of the tensor these classifications originate from. - optional string tensor_name = 11; } diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc index b20f2768c..9634635f0 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc @@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest, } } -TEST_F(TensorsToClassificationCalculatorTest, - CorrectOutputWithTensorNameAndIndex) { - mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( - calculator: "TensorsToClassificationCalculator" - input_stream: "TENSORS:tensors" - output_stream: "CLASSIFICATIONS:classifications" - options { - [mediapipe.TensorsToClassificationCalculatorOptions.ext] { - tensor_index: 1 - tensor_name: "foo" - } - } - )pb")); - - BuildGraph(&runner, {0, 0.5, 1}); - MP_ASSERT_OK(runner.Run()); - - const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; - - EXPECT_EQ(1, output_packets_.size()); - - const auto& classification_list = - output_packets_[0].Get(); - EXPECT_EQ(3, classification_list.classification_size()); - - // Verify that the tensor_index and tensor_name fields are correctly set. - EXPECT_EQ(classification_list.tensor_index(), 1); - EXPECT_EQ(classification_list.tensor_name(), "foo"); -} - TEST_F(TensorsToClassificationCalculatorTest, ClassNameAllowlistWithLabelItems) { mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( diff --git a/mediapipe/framework/formats/classification.proto b/mediapipe/framework/formats/classification.proto index c3eea07ff..7efd9074d 100644 --- a/mediapipe/framework/formats/classification.proto +++ b/mediapipe/framework/formats/classification.proto @@ -37,10 +37,6 @@ message Classification { // Group of Classification protos. message ClassificationList { repeated Classification classification = 1; - // Optional index of the tensor that produced these classifications. - optional int32 tensor_index = 2; - // Optional name of the tensor that produced these classifications. - optional string tensor_name = 3; } // Group of ClassificationList protos. From a28c9d2c2697a387795d0cd721c99024923dbd18 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 25 Oct 2022 13:41:41 -0700 Subject: [PATCH 47/55] Set `steps_per_epoch` to None when calling model.fit() method for image classifier. 
PiperOrigin-RevId: 483764377 --- .../python/vision/image_classifier/train_image_classifier_lib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py index 704d71a5a..265c36a6e 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, return model.fit( x=train_ds, epochs=hparams.train_epochs, - steps_per_epoch=hparams.steps_per_epoch, validation_data=validation_ds, callbacks=callbacks) From 254f7866249cee0f30194046019d1d758a92be5e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 26 Oct 2022 00:44:18 -0700 Subject: [PATCH 48/55] Add an option to set image preprocessing backend as gpu. PiperOrigin-RevId: 483888202 --- mediapipe/tasks/cc/components/BUILD | 1 + .../tasks/cc/components/image_preprocessing.cc | 13 +++++++++++-- .../tasks/cc/components/image_preprocessing.h | 15 +++++++++++++-- .../vision/hand_detector/hand_detector_graph.cc | 4 +++- .../hand_landmarks_detector_graph.cc | 4 +++- .../image_classifier/image_classifier_graph.cc | 4 +++- .../vision/image_embedder/image_embedder_graph.cc | 4 +++- .../image_segmenter/image_segmenter_graph.cc | 4 +++- .../object_detector/object_detector_graph.cc | 4 +++- 9 files changed, 43 insertions(+), 10 deletions(-) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index f563fbf64..344fafb4e 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -49,6 +49,7 @@ cc_library( "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index f3f3b6863..7940080e1 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -129,7 +130,7 @@ absl::Status ConfigureImageToTensorCalculator( options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / std); } - // TODO: need to.support different GPU origin on differnt + // TODO: need to support different GPU origin on differnt // platforms or applications. 
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); return absl::OkStatus(); @@ -137,7 +138,13 @@ absl::Status ConfigureImageToTensorCalculator( } // namespace +bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration) { + return acceleration.has_gpu(); +} + absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, + bool use_gpu, ImagePreprocessingOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); @@ -145,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, image_tensor_specs, options->mutable_image_to_tensor_options())); // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. - if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { + if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { + options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + } else { options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); } return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/image_preprocessing.h index a5b767f3a..6963b6556 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/image_preprocessing.h @@ -19,20 +19,26 @@ limitations under the License. #include "absl/status/status.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { -// Configures an ImagePreprocessing subgraph using the provided model resources. +// Configures an ImagePreprocessing subgraph using the provided model resources +// When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. // // Example usage: // // auto& preprocessing = // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// core::proto::Acceleration acceleration; +// acceleration.mutable_xnnpack(); +// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // model_resources, +// use_gpu, // &preprocessing.GetOptions())); // // The resulting ImagePreprocessing subgraph has the following I/O: @@ -56,9 +62,14 @@ namespace components { // The image that has the pixel data stored on the target storage (CPU vs // GPU). absl::Status ConfigureImagePreprocessing( - const core::ModelResources& model_resources, + const core::ModelResources& model_resources, bool use_gpu, ImagePreprocessingOptions* options); +// Determine if the image preprocessing subgraph should use GPU as the backend +// according to the given acceleration setting. 
+bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration); + } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index e876d7d09..06bb2e549 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph { image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 23521790d..1f127deb8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 9a0078c5c..8a1b17ce9 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index fff0f4366..f0f440986 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // stream. 
auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 629b940aa..d3e522d92 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -243,8 +243,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index 07e912cfc..b149cea0f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); From a941c5cdd11be51fcd8e0b1a7bde17b6130c37bf Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 26 Oct 2022 10:16:20 -0700 Subject: [PATCH 49/55] Create MediaPipe "tasks-text" AAR. PiperOrigin-RevId: 484004494 --- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 97 ++++++++++++++----- .../com/google/mediapipe/tasks/text/BUILD | 8 ++ 2 files changed, 79 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index e0b9c79ed..0260e3fab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -40,6 +40,10 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", ] +_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", +] + def mediapipe_tasks_core_aar(name, srcs, manifest): """Builds medaipipe tasks core AAR. 
@@ -60,6 +64,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): _mediapipe_tasks_java_proto_src_extractor(target = target), ) + for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS: + mediapipe_tasks_java_proto_srcs.append( + _mediapipe_tasks_java_proto_src_extractor(target = target), + ) + mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java", @@ -81,32 +90,35 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): ], manifest = manifest, deps = [ - "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", - "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", - "//mediapipe/framework:calculator_java_proto_lite", - "//mediapipe/framework:calculator_profile_java_proto_lite", - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework:mediapipe_options_java_proto_lite", - "//mediapipe/framework:packet_factory_java_proto_lite", - "//mediapipe/framework:packet_generator_java_proto_lite", - "//mediapipe/framework:status_handler_java_proto_lite", - "//mediapipe/framework:stream_handler_java_proto_lite", - "//mediapipe/framework/formats:classification_java_proto_lite", - "//mediapipe/framework/formats:detection_java_proto_lite", - "//mediapipe/framework/formats:landmark_java_proto_lite", - "//mediapipe/framework/formats:location_data_java_proto_lite", - "//mediapipe/framework/formats:rect_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", - "//third_party:androidx_annotation", - "//third_party:autovalue", - "@com_google_protobuf//:protobuf_javalite", - "@maven//:com_google_guava_guava", - "@maven//:com_google_flogger_flogger", - "@maven//:com_google_flogger_flogger_system_backend", - "@maven//:com_google_code_findbugs_jsr305", - ] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS, + "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", + "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework:mediapipe_options_java_proto_lite", + "//mediapipe/framework:packet_factory_java_proto_lite", + "//mediapipe/framework:packet_generator_java_proto_lite", + "//mediapipe/framework:status_handler_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/framework/formats:rect_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:androidx_annotation", + "//third_party:autovalue", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:com_google_guava_guava", + "@maven//:com_google_flogger_flogger", + 
"@maven//:com_google_flogger_flogger_system_backend", + "@maven//:com_google_code_findbugs_jsr305", + ] + + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS + + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, ) def mediapipe_tasks_vision_aar(name, srcs, native_library): @@ -142,6 +154,39 @@ EOF native_library = native_library, ) +def mediapipe_tasks_text_aar(name, srcs, native_library): + """Builds medaipipe tasks text AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Text Tasks' source files. + native_library: The native library that contains text tasks' graph and calculators. + """ + + native.genrule( + name = name + "tasks_manifest_generator", + outs = ["AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + + +EOF +""", + ) + + _mediapipe_tasks_aar( + name = name, + srcs = srcs, + manifest = "AndroidManifest.xml", + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, + native_library = native_library, + ) + def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library): """Builds medaipipe tasks AAR.""" diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 1719707d8..fa2a547c2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -61,3 +61,11 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar") + +mediapipe_tasks_text_aar( + name = "tasks_text", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_text_jni_lib", +) From f315e6dc5824dcb3fe8cd39655b4441806ecc1eb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 27 Oct 2022 10:14:34 -0700 Subject: [PATCH 50/55] Workaround to solve "-lphtread" linker error on Android. 
PiperOrigin-RevId: 484285361
---
 WORKSPACE                                     |  4 +++
 ...oogle_sentencepiece_no_gflag_no_gtest.diff | 34 +++++++++++++++++++
 2 files changed, 38 insertions(+)
 create mode 100644 third_party/com_google_sentencepiece_no_gflag_no_gtest.diff

diff --git a/WORKSPACE b/WORKSPACE
index 146916c5c..5a47cf6b7 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -172,6 +172,10 @@ http_archive(
     urls = [
         "https://github.com/google/sentencepiece/archive/1.0.0.zip",
     ],
+    patches = [
+        "//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
+    ],
+    patch_args = ["-p1"],
     repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
 )

diff --git a/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
new file mode 100644
index 000000000..a084d9262
--- /dev/null
+++ b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
@@ -0,0 +1,34 @@
+diff --git a/src/BUILD b/src/BUILD
+index b4298d2..f3877a3 100644
+--- a/src/BUILD
++++ b/src/BUILD
+@@ -71,9 +71,7 @@ cc_library(
+         ":common",
+         ":sentencepiece_cc_proto",
+         ":sentencepiece_model_cc_proto",
+-        "@com_github_gflags_gflags//:gflags",
+         "@com_google_glog//:glog",
+-        "@com_google_googletest//:gtest",
+         "@com_google_absl//absl/memory",
+         "@com_google_absl//absl/strings",
+         "@com_google_absl//absl/container:flat_hash_map",
+diff --git a/src/normalizer.h b/src/normalizer.h
+index c16ac16..2af58be 100644
+--- a/src/normalizer.h
++++ b/src/normalizer.h
+@@ -21,7 +21,6 @@
+ #include 
+ #include 
+
+-#include "gtest/gtest_prod.h"
+ #include "absl/strings/string_view.h"
+ #include "third_party/darts_clone/include/darts.h"
+ #include "src/common.h"
+@@ -97,7 +96,6 @@ class Normalizer {
+   friend class Builder;
+
+  private:
+-  FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
+
+   void Init();
+

From ee84e447b281342776b0aaf21826794f5b9ecdf7 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 27 Oct 2022 11:06:44 -0700
Subject: [PATCH 51/55] Internal change

PiperOrigin-RevId: 484299808
---
 mediapipe/framework/api2/builder.h | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h
index 82905d2f5..7dce211c8 100644
--- a/mediapipe/framework/api2/builder.h
+++ b/mediapipe/framework/api2/builder.h
@@ -289,8 +289,15 @@ class NodeBase {
   template <typename T>
   T& GetOptions() {
+    return GetOptions(T::ext);
+  }
+
+  // Use this API when the proto extension does not follow the "ext" naming
+  // convention.
+  template <typename E>
+  auto& GetOptions(const E& extension) {
     options_used_ = true;
-    return *options_.MutableExtension(T::ext);
+    return *options_.MutableExtension(extension);
   }
 
  protected:
@@ -386,8 +393,15 @@ class PacketGenerator {
   template <typename T>
   T& GetOptions() {
+    return GetOptions(T::ext);
+  }
+
+  // Use this API when the proto extension does not follow the "ext" naming
+  // convention.
+  template <typename E>
+  auto& GetOptions(const E& extension) {
     options_used_ = true;
-    return *options_.MutableExtension(T::ext);
+    return *options_.MutableExtension(extension);
   }
 
   template

From fc1d75cc99dc4bfc619225e2dd476a64fd357cc7 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 27 Oct 2022 11:14:42 -0700
Subject: [PATCH 52/55] Add CombinedPredictionCalculator.
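
The selection policy: for each input stream, take the argmax classification;
if its score clears that label's threshold it wins, otherwise the stream
falls back to its background entry. Streams are then scanned in priority
order and the first non-background winner is emitted. A stand-alone sketch of
that policy (plain C++ with hypothetical types, not the MediaPipe API):

    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // (label, score) pairs for one classifier head.
    using Scores = std::vector<std::pair<std::string, float>>;

    // Argmax with per-label thresholds; the background label is always
    // accepted (threshold 0) and serves as the fallback prediction.
    std::pair<std::string, float> Winner(
        const Scores& scores, const std::map<std::string, float>& thresholds,
        const std::string& background, float default_threshold) {
      std::pair<std::string, float> best{"", 0.f};
      for (const auto& entry : scores)
        if (entry.second > best.second) best = entry;
      auto it = thresholds.find(best.first);
      float threshold =
          best.first == background
              ? 0.f
              : (it != thresholds.end() ? it->second : default_threshold);
      if (best.second >= threshold) return best;
      for (const auto& entry : scores)  // fall back to the background score
        if (entry.first == background) return entry;
      return {background, 0.f};
    }

    int main() {
      const std::map<std::string, float> thresholds = {{"CustomDrama", 0.7f}};
      const Scores stream = {{"Negative", 0.2f}, {"CustomDrama", 0.5f}};
      const auto result = Winner(stream, thresholds, "Negative", 0.1f);
      // Prints "Negative 0.2": the argmax misses its 0.7 threshold.
      std::cout << result.first << " " << result.second << "\n";
      return 0;
    }
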
PiperOrigin-RevId: 484301880 --- .../gesture_recognizer/calculators/BUILD | 43 +++ .../combined_prediction_calculator.cc | 187 +++++++++++ .../combined_prediction_calculator.proto | 41 +++ .../combined_prediction_calculator_test.cc | 315 ++++++++++++++++++ 4 files changed, 586 insertions(+) create mode 100644 mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc create mode 100644 mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto create mode 100644 mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 08f7f45d0..8c2c2e593 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -93,3 +93,46 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +mediapipe_proto_library( + name = "combined_prediction_calculator_proto", + srcs = ["combined_prediction_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "combined_prediction_calculator", + srcs = ["combined_prediction_calculator.cc"], + deps = [ + ":combined_prediction_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "combined_prediction_calculator_test", + srcs = ["combined_prediction_calculator_test.cc"], + deps = [ + ":combined_prediction_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc new file mode 100644 index 000000000..c7147ea6e --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc @@ -0,0 +1,187 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/container/btree_map.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+constexpr char kPredictionTag[] = "PREDICTION";
+
+Classification GetMaxScoringClassification(
+    const ClassificationList& classifications) {
+  Classification max_classification;
+  max_classification.set_score(0);
+  for (const auto& input : classifications.classification()) {
+    if (max_classification.score() < input.score()) {
+      max_classification = input;
+    }
+  }
+  return max_classification;
+}
+
+float GetScoreThreshold(
+    const std::string& input_label,
+    const absl::btree_map<std::string, float>& classwise_thresholds,
+    const std::string& background_label, const float default_threshold) {
+  float threshold = default_threshold;
+  auto it = classwise_thresholds.find(input_label);
+  if (it != classwise_thresholds.end()) {
+    threshold = it->second;
+  }
+  return threshold;
+}
+
+std::unique_ptr<ClassificationList> GetWinningPrediction(
+    const ClassificationList& classification_list,
+    const absl::btree_map<std::string, float>& classwise_thresholds,
+    const std::string& background_label, const float default_threshold) {
+  auto prediction_list = std::make_unique<ClassificationList>();
+  if (classification_list.classification().empty()) {
+    return prediction_list;
+  }
+  Classification& prediction = *prediction_list->add_classification();
+  auto argmax_prediction = GetMaxScoringClassification(classification_list);
+  float argmax_prediction_thresh =
+      GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
+                        background_label, default_threshold);
+  if (argmax_prediction.score() >= argmax_prediction_thresh) {
+    prediction.set_label(argmax_prediction.label());
+    prediction.set_score(argmax_prediction.score());
+  } else {
+    for (const auto& input : classification_list.classification()) {
+      if (input.label() == background_label) {
+        prediction.set_label(input.label());
+        prediction.set_score(input.score());
+        break;
+      }
+    }
+  }
+  return prediction_list;
+}
+
+}  // namespace
+
+// This calculator accepts multiple ClassificationList input streams. Each
+// ClassificationList should contain classifications with labels and
+// corresponding softmax scores. The calculator computes the best prediction
+// for each ClassificationList input stream via argmax and thresholding.
+// Thresholds for all classes can be specified in the
+// `CombinedPredictionCalculatorOptions`, along with a default global
+// threshold.
+// Note that for this calculator to work as designed, the non-background class
+// names must be distinct across the ClassificationList objects, while the
+// background class must carry the same label in all of them. This background
+// label name can be set via `background_label` in
+// `CombinedPredictionCalculatorOptions`.
+// The ClassificationList in the PREDICTION output stream contains the label of
+// the winning class and its corresponding softmax score. If none of the
+// ClassificationList objects has a non-background winning class, the output
+// contains the background class and the score of the background class in the
+// first ClassificationList. If multiple ClassificationList objects have a
+// non-background winning class, the output contains the winning prediction
+// from the ClassificationList with the highest priority. Priority is in
+// decreasing order of input streams to the graph node using this calculator.
+// Input:
+//   At least one stream with ClassificationList.
+// Output:
+//   PREDICTION - A ClassificationList with the winning label as the only item.
+//
+// Usage example:
+// node {
+//   calculator: "CombinedPredictionCalculator"
+//   input_stream: "classification_list_0"
+//   input_stream: "classification_list_1"
+//   output_stream: "PREDICTION:prediction"
+//   options {
+//     [mediapipe.CombinedPredictionCalculatorOptions.ext] {
+//       class {
+//         label: "A"
+//         score_threshold: 0.7
+//       }
+//       default_global_threshold: 0.1
+//       background_label: "B"
+//     }
+//   }
+// }
+
+class CombinedPredictionCalculator : public Node {
+ public:
+  static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
+      ""};
+  static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
+  MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
+
+  absl::Status Open(CalculatorContext* cc) override {
+    options_ = cc->Options<CombinedPredictionCalculatorOptions>();
+    for (const auto& input : options_.class_()) {
+      classwise_thresholds_[input.label()] = input.score_threshold();
+    }
+    classwise_thresholds_[options_.background_label()] = 0;
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) override {
+    // Scan the inputs in priority order and send the first non-background
+    // winner. If only background predictions are found, send the first of
+    // those; if every input is empty, send no packet at all.
+    std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
+    auto collection = kClassificationListIn(cc);
+    for (int idx = 0; idx < collection.Count(); ++idx) {
+      const auto& packet = collection[idx];
+      if (packet.IsEmpty()) {
+        continue;
+      }
+      auto prediction = GetWinningPrediction(
+          packet.Get(), classwise_thresholds_, options_.background_label(),
+          options_.default_global_threshold());
+      if (prediction->classification(0).label() !=
+          options_.background_label()) {
+        kPredictionOut(cc).Send(std::move(prediction));
+        return absl::OkStatus();
+      }
+      if (first_winning_prediction == nullptr) {
+        first_winning_prediction = std::move(prediction);
+      }
+    }
+    if (first_winning_prediction != nullptr) {
+      kPredictionOut(cc).Send(std::move(first_winning_prediction));
+    }
+    return absl::OkStatus();
+  }
+
+ private:
+  CombinedPredictionCalculatorOptions options_;
+  absl::btree_map<std::string, float> classwise_thresholds_;
+};
+
+MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto
new file mode 100644
index 000000000..730e7dd78
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto
@@ -0,0 +1,41 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message CombinedPredictionCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional CombinedPredictionCalculatorOptions ext = 483738635;
+  }
+
+  message Class {
+    optional string label = 1;
+    optional float score_threshold = 2;
+  }
+
+  // List of classes with score thresholds.
+  repeated Class class = 1;
+
+  // Default score threshold applied to labels that have no class-specific
+  // threshold.
+  optional float default_global_threshold = 2 [default = 0];
+
+  // Name of the background class whose input scores will be ignored while
+  // thresholding.
+  optional string background_label = 3;
+}
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc
new file mode 100644
index 000000000..ecf49795b
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc
@@ -0,0 +1,315 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+
+namespace {
+
+constexpr char kPredictionTag[] = "PREDICTION";
+
+std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
+    float drama_thresh, float llama_thresh, float bazinga_thresh,
+    float joy_thresh, float peace_thresh) {
+  constexpr absl::string_view kCalculatorProto = R"pb(
+    calculator: "CombinedPredictionCalculator"
+    input_stream: "custom_softmax_scores"
+    input_stream: "canned_softmax_scores"
+    output_stream: "PREDICTION:prediction"
+    options {
+      [mediapipe.CombinedPredictionCalculatorOptions.ext] {
+        class { label: "CustomDrama" score_threshold: $0 }
+        class { label: "CustomLlama" score_threshold: $1 }
+        class { label: "CannedBazinga" score_threshold: $2 }
+        class { label: "CannedJoy" score_threshold: $3 }
+        class { label: "CannedPeace" score_threshold: $4 }
+        background_label: "Negative"
+      }
+    }
+  )pb";
+  auto runner = std::make_unique<CalculatorRunner>(
+      absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
+                       bazinga_thresh, joy_thresh, peace_thresh));
+  return runner;
+}
+
+std::unique_ptr<ClassificationList> BuildCustomScoreInput(
+    const float negative_score, const float drama_score,
+    const float llama_score) {
+  auto custom_scores = std::make_unique<ClassificationList>();
+  auto custom_negative = custom_scores->add_classification();
+  custom_negative->set_label("Negative");
+  custom_negative->set_score(negative_score);
+  auto drama = custom_scores->add_classification();
+  drama->set_label("CustomDrama");
+  drama->set_score(drama_score);
+  auto llama = custom_scores->add_classification();
+  llama->set_label("CustomLlama");
+  llama->set_score(llama_score);
+  return custom_scores;
+}
+
+std::unique_ptr<ClassificationList> BuildCannedScoreInput(
+    const float negative_score, const float bazinga_score,
+    const float joy_score, const float peace_score) {
+  auto canned_scores = std::make_unique<ClassificationList>();
+  auto canned_negative = canned_scores->add_classification();
+  canned_negative->set_label("Negative");
+  canned_negative->set_score(negative_score);
+  auto bazinga = canned_scores->add_classification();
+  bazinga->set_label("CannedBazinga");
+  bazinga->set_score(bazinga_score);
+  auto joy = canned_scores->add_classification();
+  joy->set_label("CannedJoy");
+  joy->set_score(joy_score);
+  auto peace = canned_scores->add_classification();
+  peace->set_label("CannedPeace");
+  peace->set_score(peace_score);
+  return canned_scores;
+}
+
+TEST(CombinedPredictionCalculatorPacketTest,
+     CustomEmpty_CannedEmpty_ResultIsEmpty) {
+  auto runner = BuildNodeRunnerWithOptions(
+      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
+      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
+  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
+  EXPECT_THAT(runner->Outputs().Tag(kPredictionTag).packets,
+              testing::IsEmpty());
+}
+
+TEST(CombinedPredictionCalculatorPacketTest,
+     CustomEmpty_CannedNotEmpty_ResultIsCanned) {
+  auto runner = BuildNodeRunnerWithOptions(
+      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
+      /*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
+  auto canned_scores = BuildCannedScoreInput(
+      /*negative_score=*/0.1,
+      /*bazinga_score=*/0.1, /*joy_score=*/0.6,
+      /*peace_score=*/0.2);
+  runner->MutableInputs()->Index(1).packets.push_back(
+      Adopt(canned_scores.release()).At(Timestamp(1)));
+  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
+
+  auto output_prediction_packets =
+      runner->Outputs().Tag(kPredictionTag).packets;
+  ASSERT_EQ(output_prediction_packets.size(), 1);
+  Classification output_prediction =
+      output_prediction_packets[0].Get<ClassificationList>().classification(0);
+
+  EXPECT_EQ(output_prediction.label(), "CannedJoy");
+  EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
+}
+
+TEST(CombinedPredictionCalculatorPacketTest,
+     CustomNotEmpty_CannedEmpty_ResultIsCustom) {
+  auto runner = BuildNodeRunnerWithOptions(
+      /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
+      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
+  auto custom_scores =
+      BuildCustomScoreInput(/*negative_score=*/0.1,
+                            /*drama_score=*/0.2, /*llama_score=*/0.7);
+  runner->MutableInputs()->Index(0).packets.push_back(
+      Adopt(custom_scores.release()).At(Timestamp(1)));
+  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
+
+  auto output_prediction_packets =
+      runner->Outputs().Tag(kPredictionTag).packets;
+  ASSERT_EQ(output_prediction_packets.size(), 1);
+  Classification output_prediction =
+      output_prediction_packets[0].Get<ClassificationList>().classification(0);
+
+  EXPECT_EQ(output_prediction.label(), "CustomLlama");
+  EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
+}
+
+struct CombinedPredictionCalculatorTestCase {
+  std::string test_name;
+  float custom_negative_score;
+  float drama_score;
+  float llama_score;
+  float drama_thresh;
+  float llama_thresh;
+  float canned_negative_score;
+  float bazinga_score;
+  float joy_score;
+  float peace_score;
+  float bazinga_thresh;
+  float joy_thresh;
+  float peace_thresh;
+  std::string max_scoring_label;
+  float max_score;
+};
+
+using CombinedPredictionCalculatorTest =
+    testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
+
+TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
+  const CombinedPredictionCalculatorTestCase& test_case = GetParam();
+
+  auto runner = BuildNodeRunnerWithOptions(
+      test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
+      test_case.joy_thresh, test_case.peace_thresh);
+
+  auto custom_scores =
+      BuildCustomScoreInput(test_case.custom_negative_score,
+                            test_case.drama_score, test_case.llama_score);
+
+  runner->MutableInputs()->Index(0).packets.push_back(
+      Adopt(custom_scores.release()).At(Timestamp(1)));
+
+  auto canned_scores = BuildCannedScoreInput(
+      test_case.canned_negative_score, test_case.bazinga_score,
+      test_case.joy_score, test_case.peace_score);
+  runner->MutableInputs()->Index(1).packets.push_back(
+      Adopt(canned_scores.release()).At(Timestamp(1)));
+
+  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
+
+  auto output_prediction_packets =
+      runner->Outputs().Tag(kPredictionTag).packets;
+  ASSERT_EQ(output_prediction_packets.size(), 1);
+  Classification output_prediction =
+      output_prediction_packets[0].Get<ClassificationList>().classification(0);
+
+  EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
+  EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
+    testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
+        {
+            .test_name = "TestCustomDramaWinnerWith_HighCanned_Thresh",
+            .custom_negative_score = 0.1,
+            .drama_score = 0.5,
+            .llama_score = 0.3,
+            .drama_thresh = 0.25,
+            .llama_thresh = 0.7,
+            .canned_negative_score = 0.1,
+            .bazinga_score = 0.3,
+            .joy_score = 0.3,
+            .peace_score = 0.3,
+ .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "CustomDrama", + .max_score = 0.5, + }, + { + .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.1, + .bazinga_score = 0.4, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.0, + .joy_thresh = 0.0, + .peace_thresh = 0.0, + .max_scoring_label = "CannedBazinga", + .max_score = 0.4, + }, + { + .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh", + .custom_negative_score = 0.5, + .drama_score = 0.1, + .llama_score = 0.4, + .drama_thresh = 0.1, + .llama_thresh = 0.05, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.5, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh", + .custom_negative_score = 0.8, + .drama_score = 0.1, + .llama_score = 0.1, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.8, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2", + .custom_negative_score = 0.1, + .drama_score = 0.2, + .llama_score = 0.7, + .drama_thresh = 1.1, + .llama_thresh = 1.1, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.3, + .bazinga_score = 0.2, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.5, + .joy_thresh = 0.5, + .peace_thresh = 0.5, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + }), + [](const testing::TestParamInfo< + CombinedPredictionCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace mediapipe From abd1ff66c869bc4543f646c9bffcb5ed89c26d98 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 27 Oct 2022 15:30:56 -0700 Subject: [PATCH 53/55] Fix https://github.com/google/mediapipe/issues/3784 PiperOrigin-RevId: 484365654 --- mediapipe/modules/face_geometry/libs/effect_renderer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/modules/face_geometry/libs/effect_renderer.cc b/mediapipe/modules/face_geometry/libs/effect_renderer.cc index 27a54e011..73f473084 100644 --- a/mediapipe/modules/face_geometry/libs/effect_renderer.cc +++ b/mediapipe/modules/face_geometry/libs/effect_renderer.cc @@ -161,7 +161,7 @@ class Texture { ~Texture() { if (is_owned_) { - glDeleteProgram(handle_); + glDeleteTextures(1, &handle_); } } From de5fe27e05f20b52b5394e681f116210db9715d3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 27 Oct 2022 19:06:04 -0700 Subject: [PATCH 54/55] Modified internal dependencies. 
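
The substantive change below: the OpenCL paths in tflite_gpu_runner, which
were guarded by "#ifdef __ANDROID__", are widened to
"#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)" so the CL backend
also builds on ChromeOS, and the BUILD deps are reorganized to match. The
guard pattern in isolation (an illustrative sketch with a made-up function,
not the real API):

    #include <string>

    // Reports which GPU inference backend this translation unit was compiled
    // with, mirroring the widened platform guard used in the patch.
    std::string CompiledGpuBackend() {
    #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
      return "OpenCL, with OpenGL fallback";  // CL delegate compiled in
    #else
      return "OpenGL only";  // CL path compiled out on other platforms
    #endif
    }
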
PiperOrigin-RevId: 484407262 --- mediapipe/util/tflite/BUILD | 47 ++++++++++------------ mediapipe/util/tflite/tflite_gpu_runner.cc | 17 ++++---- mediapipe/util/tflite/tflite_gpu_runner.h | 17 ++++---- 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 9d37b60a0..e9b8bfa03 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -84,33 +84,28 @@ cc_library( "//conditions:default": ["tflite_gpu_runner.h"], }), deps = select({ - "//mediapipe:ios": [], - "//mediapipe:macos": [], - "//conditions:default": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - "//mediapipe:android": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - }) + [ + "//mediapipe:ios": [], + "//mediapipe:macos": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/delegates/gpu:api", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", + ], + }) + + select({ + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", + ], + "//conditions:default": [], + }) + [ "@com_google_absl//absl/status", + "//mediapipe/framework:port", "@org_tensorflow//tensorflow/lite/core/api", ], ) diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 4c422835a..4e40975cb 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -34,7 +34,7 @@ // This code should be enabled as soon as TensorFlow version, which mediapipe // uses, will include this module. 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 #include "tensorflow/lite/delegates/gpu/cl/api.h"
 #endif
 
@@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) {
   return gpu_object_def;
 }
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
 cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
   cl::InferenceOptions result{};
@@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector& actual,
   return absl::OkStatus();
 }
 
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
 }  // namespace
 
@@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL(
 
 absl::Status TFLiteGPURunner::InitializeOpenCL(
     std::unique_ptr<InferenceBuilder>* builder) {
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
   cl::InferenceEnvironmentOptions env_options;
   if (!serialized_binary_cache_.empty()) {
     env_options.serialized_binary_cache = serialized_binary_cache_;
@@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
   return absl::OkStatus();
 
 #else
-  return mediapipe::UnimplementedError("Currently only Android is supported");
-#endif  // __ANDROID__
+  return mediapipe::UnimplementedError(
+      "Currently only Android & ChromeOS are supported");
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 }
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
 absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
     std::unique_ptr<InferenceBuilder>* builder) {
@@ -283,7 +284,7 @@ absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
   return serialized_model;
 }
 
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
 }  // namespace gpu
 }  // namespace tflite
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h
index 88d3914f7..dfbc8d659 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.h
+++ b/mediapipe/util/tflite/tflite_gpu_runner.h
@@ -20,6 +20,7 @@
 #include 
 
 #include "absl/status/status.h"
+#include "mediapipe/framework/port.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
@@ -28,9 +29,9 @@
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/model.h"
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 #include "tensorflow/lite/delegates/gpu/cl/api.h"
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
 namespace tflite {
 namespace gpu {
@@ -83,7 +84,7 @@ class TFLiteGPURunner {
     return output_shape_from_model_;
   }
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
   void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) {
     serialized_binary_cache_ = std::move(cache);
   }
@@ -98,26 +99,26 @@ class TFLiteGPURunner {
   }
 
   absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
  private:
   absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
   absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
   absl::Status InitializeOpenCLFromSerializedModel(
       std::unique_ptr<InferenceBuilder>* builder);
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
   InferenceOptions options_;
   std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
 
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
   std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
   std::vector<uint8_t> serialized_binary_cache_;
   std::vector<uint8_t> serialized_model_;
   bool serialized_model_used_ = false;
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
 
   // graph_gl_ is maintained temporarily and becomes invalid after runner_ is
   // ready

From 87b201b1a38c63dc8bfb2c7dc535de095e53c7ed Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Thu, 27 Oct 2022 21:11:48 -0700
Subject: [PATCH 55/55] Adds a basic script for generating API docs from the
 `mediapipe` Python package.

This is pretty bare-bones for now, but I need something to start wiring up
the rest of the automation. Plus this should be easy enough for anyone to
riff on once it's in place.

I couldn't find a great location within the existing directory structure for
this, so LMK if it should be re-homed.

PiperOrigin-RevId: 484426543
---
 docs/BUILD                | 14 +++++++
 docs/build_py_api_docs.py | 85 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+)
 create mode 100644 docs/BUILD
 create mode 100644 docs/build_py_api_docs.py

diff --git a/docs/BUILD b/docs/BUILD
new file mode 100644
index 000000000..cb8794dab
--- /dev/null
+++ b/docs/BUILD
@@ -0,0 +1,14 @@
+# Placeholder for internal Python strict binary compatibility macro.
+
+py_binary(
+    name = "build_py_api_docs",
+    srcs = ["build_py_api_docs.py"],
+    deps = [
+        "//mediapipe",
+        "//third_party/py/absl:app",
+        "//third_party/py/absl/flags",
+        "//third_party/py/tensorflow_docs",
+        "//third_party/py/tensorflow_docs/api_generator:generate_lib",
+        "//third_party/py/tensorflow_docs/api_generator:public_api",
+    ],
+)
diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py
new file mode 100644
index 000000000..9911d0736
--- /dev/null
+++ b/docs/build_py_api_docs.py
@@ -0,0 +1,85 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""MediaPipe reference docs generation script.
+
+This script generates API reference docs for the `mediapipe` PIP package.
+
+$> pip install -U git+https://github.com/tensorflow/docs mediapipe
+$> python build_py_api_docs.py
+"""
+
+import os
+
+from absl import app
+from absl import flags
+
+from tensorflow_docs.api_generator import generate_lib
+from tensorflow_docs.api_generator import public_api
+
+try:
+  # mediapipe has not been set up to work with bazel yet, so catch & report.
+ import mediapipe # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe`.') from e + + +PROJECT_SHORT_NAME = 'mp' +PROJECT_FULL_NAME = 'MediaPipe' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe)], + base_dir=os.path.dirname(mediapipe.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + # This callback ensures that docs are only generated for objects that + # are explicitly imported in your __init__.py files. There are other + # options but this is a good starting point. + callbacks=[public_api.explicit_package_contents_filter], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main)