From 2a2a55d1b84c908120fe237b9e892c708d05c532 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 21 Apr 2023 11:46:21 -0700 Subject: [PATCH 01/28] Added Language Detector Python API and fixed a typo in Interactive Segmenter Options' docstring --- mediapipe/tasks/python/test/text/BUILD | 15 ++ .../test/text/language_detector_test.py | 225 ++++++++++++++++++ mediapipe/tasks/python/text/BUILD | 20 ++ .../tasks/python/text/language_detector.py | 205 ++++++++++++++++ .../python/vision/interactive_segmenter.py | 2 +- mediapipe/tasks/testdata/text/BUILD | 6 + 6 files changed, 472 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/python/test/text/language_detector_test.py create mode 100644 mediapipe/tasks/python/text/language_detector.py diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 5f2d18bc5..5f8551636 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -49,3 +49,18 @@ py_test( "//mediapipe/tasks/python/text:text_embedder", ], ) + +py_test( + name = "language_detector_test", + srcs = ["language_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/text:language_detector_model", + ], + deps = [ + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/text:language_detector", + ], +) diff --git a/mediapipe/tasks/python/test/text/language_detector_test.py b/mediapipe/tasks/python/test/text/language_detector_test.py new file mode 100644 index 000000000..2443d4312 --- /dev/null +++ b/mediapipe/tasks/python/test/text/language_detector_test.py @@ -0,0 +1,225 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for language detector.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import language_detector + +LanguageDetectorResult = language_detector.LanguageDetectorResult +LanguageDetectorPrediction = language_detector.LanguageDetectorResult.LanguageDetectorPrediction +_BaseOptions = base_options_module.BaseOptions +_Category = category.Category +_Classifications = classification_result_module.Classifications +_LanguageDetector = language_detector.LanguageDetector +_LanguageDetectorOptions = language_detector.LanguageDetectorOptions + +_LANGUAGE_DETECTOR_MODEL = 'language_detector.tflite' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' + +_SCORE_THRESHOLD = 0.3 +_EN_TEXT = "To be, or not to be, that is the question" +_EN_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("en", 0.999856) + ] +) +_FR_TEXT = ( + "Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent." +) +_FR_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("fr", 0.999781) + ] +) +_RU_TEXT = "это какой-то английский язык" +_RU_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("ru", 0.993362) + ] +) +_MIXED_TEXT = "分久必合合久必分" +_MIXED_EXPECTED_RESULT = LanguageDetectorResult( + [ + LanguageDetectorPrediction("zh", 0.505424), + LanguageDetectorPrediction("ja", 0.481617) + ] +) +_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class LanguageDetectorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL)) + + def _expect_language_detector_result_correct( + self, + actual_result: LanguageDetectorResult, + expect_result: LanguageDetectorResult + ): + for i, prediction in enumerate(actual_result.languages_and_scores): + expected_prediction = expect_result.languages_and_scores[i] + self.assertEqual( + prediction.language_code, expected_prediction.language_code, + ) + self.assertAlmostEqual( + prediction.probability, expected_prediction.probability, + delta=_TOLERANCE + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _LanguageDetector.create_from_model_path(self.model_path) as detector: + self.assertIsInstance(detector, _LanguageDetector) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. 
+ base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions(base_options=base_options) + with _LanguageDetector.create_from_options(options) as detector: + self.assertIsInstance(detector, _LanguageDetector) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _LanguageDetectorOptions(base_options=base_options) + _LanguageDetector.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _LanguageDetectorOptions(base_options=base_options) + detector = _LanguageDetector.create_from_options(options) + self.assertIsInstance(detector, _LanguageDetector) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT) + ) + def test_detect(self, model_file_type, text, expected_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD + ) + detector = _LanguageDetector.create_from_options(options) + + # Performs language detection on the input. + text_result = detector.detect(text) + # Comparing results. + self._expect_language_detector_result_correct(text_result, expected_result) + # Closes the detector explicitly when the detector is not used in + # a context. + detector.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT) + ) + def test_detect_in_context(self, model_file_type, text, expected_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(text) + # Comparing results. 
+ self._expect_language_detector_result_correct(text_result, expected_result) + + def test_allowlist_option(self): + # Creates detector. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD, + category_allowlist=["ja"] + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(_MIXED_TEXT) + # Comparing results. + expected_result = LanguageDetectorResult( + [ + LanguageDetectorPrediction("ja", 0.481617) + ] + ) + self._expect_language_detector_result_correct(text_result, expected_result) + + def test_denylist_option(self): + # Creates detector. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD, + category_denylist=["ja"] + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(_MIXED_TEXT) + # Comparing results. + expected_result = LanguageDetectorResult( + [ + LanguageDetectorPrediction("zh", 0.505424) + ] + ) + self._expect_language_detector_result_correct(text_result, expected_result) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index 9d5d23261..b1dd3feb9 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -57,3 +57,23 @@ py_library( "//mediapipe/tasks/python/text/core:base_text_task_api", ], ) + +py_library( + name = "language_detector", + srcs = [ + "language_detector.py", + ], + visibility = ["//mediapipe/tasks:users"], + deps = [ + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/text/core:base_text_task_api", + ], +) diff --git a/mediapipe/tasks/python/text/language_detector.py b/mediapipe/tasks/python/text/language_detector.py new file mode 100644 index 000000000..8d933dd59 --- /dev/null +++ b/mediapipe/tasks/python/text/language_detector.py @@ -0,0 +1,205 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe language detector task.""" + +import dataclasses +from typing import Optional, List + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 +from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +_ClassificationResult = classification_result_module.ClassificationResult +_BaseOptions = base_options_module.BaseOptions +_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions +_TaskInfo = task_info_module.TaskInfo + +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' + + +@dataclasses.dataclass +class LanguageDetectorResult: + + @dataclasses.dataclass + class LanguageDetectorPrediction: + """A language code and its probability.""" + language_code: str + probability: float + + languages_and_scores: List[LanguageDetectorPrediction] + + +def _extract_language_detector_result( + classification_result: classification_result_module.ClassificationResult +) -> LanguageDetectorResult: + if len(classification_result.classifications) != 1: + raise ValueError( + "The LanguageDetector TextClassifierGraph should have exactly one " + "classification head." + ) + languages_and_scores = classification_result.classifications[0] + language_detector_result = LanguageDetectorResult([]) + for category in languages_and_scores.categories: + if category.category_name is None: + raise ValueError( + "LanguageDetector ClassificationResult has a missing language code.") + prediction = LanguageDetectorResult.LanguageDetectorPrediction( + category.category_name, category.score + ) + language_detector_result.languages_and_scores.append(prediction) + return language_detector_result + + +@dataclasses.dataclass +class LanguageDetectorOptions: + """Options for the language detector task. + + Attributes: + base_options: Base options for the language detector task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. 
+  """
+  base_options: _BaseOptions
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+  score_threshold: Optional[float] = None
+  category_allowlist: Optional[List[str]] = None
+  category_denylist: Optional[List[str]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
+    """Generates a TextClassifierGraphOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    classifier_options_proto = _ClassifierOptionsProto(
+        score_threshold=self.score_threshold,
+        category_allowlist=self.category_allowlist,
+        category_denylist=self.category_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)
+
+    return _TextClassifierGraphOptionsProto(
+        base_options=base_options_proto,
+        classifier_options=classifier_options_proto)
+
+
+class LanguageDetector(base_text_task_api.BaseTextTaskApi):
+  """Class that predicts the language of an input text.
+
+  This API expects a TFLite model with TFLite Model Metadata that contains the
+  mandatory (described below) input tensors, output tensor, and the language
+  codes in an AssociatedFile.
+
+  Input tensors:
+    (kTfLiteString)
+    - 1 input tensor that is scalar or has shape [1] containing the input
+      string.
+  Output tensor:
+    (kTfLiteFloat32)
+    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
+  """
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
+    """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `LanguageDetector` object that's created from the model file and the
+      default `LanguageDetectorOptions`.
+
+    Raises:
+      ValueError: If failed to create `LanguageDetector` object from the
+        provided file such as invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = LanguageDetectorOptions(base_options=base_options)
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(cls,
+                          options: LanguageDetectorOptions) -> 'LanguageDetector':
+    """Creates the `LanguageDetector` object from language detector options.
+
+    Args:
+      options: Options for the language detector task.
+
+    Returns:
+      `LanguageDetector` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `LanguageDetector` object from
+        `LanguageDetectorOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
+        output_streams=[
+            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
+        ],
+        task_options=options)
+    return cls(task_info.generate_graph_config())
+
+  def detect(self, text: str) -> LanguageDetectorResult:
+    """Predicts the language of the input `text`.
+
+    Args:
+      text: The input text.
+
+    Returns:
+      A `LanguageDetectorResult` object that contains a list of languages and
+      scores.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid.
+      RuntimeError: If language detection failed to run.
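+
+    Example (an illustrative sketch; the model path is a placeholder):
+
+      with LanguageDetector.create_from_model_path(
+          '/path/to/language_detector.tflite') as detector:
+        result = detector.detect('To be, or not to be, that is the question')
+        top = result.languages_and_scores[0]
+        print(top.language_code, top.probability)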
+ """ + output_packets = self._runner.process( + {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) + + classification_result = _ClassificationResult.create_from_pb2( + classification_result_proto + ) + return _extract_language_detector_result(classification_result) diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index ad93c798c..1d9f5cf1a 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -88,7 +88,7 @@ class InteractiveSegmenterOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: - """Generates an InteractiveSegmenterOptions protobuf object.""" + """Generates an ImageSegmenterGraphOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False segmenter_options_proto = _SegmenterOptionsProto() diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 9813b6543..62251ed8b 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -31,6 +31,7 @@ mediapipe_files(srcs = [ "bert_text_classifier.tflite", "mobilebert_embedding_with_metadata.tflite", "mobilebert_with_metadata.tflite", + "language_detector.tflite", "regex_one_embedding_with_metadata.tflite", "test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite", @@ -78,6 +79,11 @@ filegroup( ], ) +filegroup( + name = "language_detector_model", + srcs = ["language_detector.tflite"], +) + filegroup( name = "text_classifier_models", srcs = [ From 0b1eb39870ea6b61e6805f15262ebe7fd2241c8e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 21 Apr 2023 11:48:06 -0700 Subject: [PATCH 02/28] Updated copyright --- mediapipe/tasks/python/test/text/language_detector_test.py | 2 +- mediapipe/tasks/python/text/language_detector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/test/text/language_detector_test.py b/mediapipe/tasks/python/test/text/language_detector_test.py index 2443d4312..69a9c092b 100644 --- a/mediapipe/tasks/python/test/text/language_detector_test.py +++ b/mediapipe/tasks/python/test/text/language_detector_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/text/language_detector.py b/mediapipe/tasks/python/text/language_detector.py index 8d933dd59..1cc363d98 100644 --- a/mediapipe/tasks/python/text/language_detector.py +++ b/mediapipe/tasks/python/text/language_detector.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From 305866ccaeaa614e700f9b9a698af9514bdf17b7 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 27 Apr 2023 21:11:53 -0700 Subject: [PATCH 03/28] Updated BUILD files to use the open sourced Language Detector model --- mediapipe/tasks/python/test/text/BUILD | 2 +- mediapipe/tasks/testdata/text/BUILD | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 8d25f2781..be352c84d 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -54,7 +54,7 @@ py_test( name = "language_detector_test", srcs = ["language_detector_test.py"], data = [ - "//mediapipe/tasks/testdata/text:language_detector_model", + "//mediapipe/tasks/testdata/text:language_detector", ], deps = [ "//mediapipe/tasks/python/components/containers:category", diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 0a50a229a..701ed0dfa 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -80,11 +80,6 @@ filegroup( ], ) -filegroup( - name = "language_detector_model", - srcs = ["language_detector.tflite"], -) - filegroup( name = "text_classifier_models", srcs = [ From 3b06772d9a2616c86d34c8605d023e7862f6a8e9 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 27 Apr 2023 21:13:31 -0700 Subject: [PATCH 04/28] Fixed BUILD --- mediapipe/tasks/testdata/text/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 701ed0dfa..f1f0dc814 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -32,7 +32,6 @@ mediapipe_files(srcs = [ "language_detector.tflite", "mobilebert_embedding_with_metadata.tflite", "mobilebert_with_metadata.tflite", - "language_detector.tflite", "regex_one_embedding_with_metadata.tflite", "test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite", From 209d78f36cac23e0de035d2f01976a620c04d01f Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 1 May 2023 05:55:46 -0700 Subject: [PATCH 05/28] Added the Face Aligner Python API --- .../python/test/vision/face_aligner_test.py | 208 ++++++++++++++++++ mediapipe/tasks/python/vision/face_aligner.py | 201 +++++++++++++++++ 2 files changed, 409 insertions(+) create mode 100644 mediapipe/tasks/python/test/vision/face_aligner_test.py create mode 100644 mediapipe/tasks/python/vision/face_aligner.py diff --git a/mediapipe/tasks/python/test/vision/face_aligner_test.py b/mediapipe/tasks/python/test/vision/face_aligner_test.py new file mode 100644 index 000000000..7f2f8dd7e --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_aligner_test.py @@ -0,0 +1,208 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for face aligner.""" + +import enum +import os +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.components.containers import rect +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_aligner +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module + + +_BaseOptions = base_options_module.BaseOptions +_Rect = rect.Rect +_Image = image_module.Image +_FaceAligner = face_aligner.FaceAligner +_FaceAlignerOptions = face_aligner.FaceAlignerOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_MODEL = 'face_stylizer.task' +_LARGE_FACE_IMAGE = "portrait.jpg" +_MODEL_IMAGE_SIZE = 256 +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceAlignerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE))) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _MODEL)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceAligner.create_from_model_path(self.model_path) as aligner: + self.assertIsInstance(aligner, _FaceAligner) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceAlignerOptions(base_options=base_options) + with _FaceAligner.create_from_options(options) as aligner: + self.assertIsInstance(aligner, _FaceAligner) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _FaceAlignerOptions(base_options=base_options) + _FaceAligner.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceAlignerOptions(base_options=base_options) + aligner = _FaceAligner.create_from_options(options) + self.assertIsInstance(aligner, _FaceAligner) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE), + (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE) + ) + def test_align(self, model_file_type, image_file_name): + # Load the test image. + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, image_file_name))) + # Creates aligner. 
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(model_asset_path=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(model_asset_buffer=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _FaceAlignerOptions(base_options=base_options)
+    aligner = _FaceAligner.create_from_options(options)
+
+    # Performs face alignment on the input.
+    aligned_image = aligner.align(self.test_image)
+    self.assertIsInstance(aligned_image, _Image)
+    # Closes the aligner explicitly when the aligner is not used in
+    # a context.
+    aligner.close()
+
+  @parameterized.parameters(
+      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
+      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE)
+  )
+  def test_align_in_context(self, model_file_type, image_file_name):
+    # Load the test image.
+    self.test_image = _Image.create_from_file(
+        test_utils.get_test_data_path(
+            os.path.join(_TEST_DATA_DIR, image_file_name)))
+    # Creates aligner.
+    if model_file_type is ModelFileType.FILE_NAME:
+      base_options = _BaseOptions(model_asset_path=self.model_path)
+    elif model_file_type is ModelFileType.FILE_CONTENT:
+      with open(self.model_path, 'rb') as f:
+        model_content = f.read()
+      base_options = _BaseOptions(model_asset_buffer=model_content)
+    else:
+      # Should never happen
+      raise ValueError('model_file_type is invalid.')
+
+    options = _FaceAlignerOptions(base_options=base_options)
+    with _FaceAligner.create_from_options(options) as aligner:
+      # Performs face alignment on the input.
+      aligned_image = aligner.align(self.test_image)
+      self.assertIsInstance(aligned_image, _Image)
+      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
+      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)
+
+  def test_align_succeeds_with_region_of_interest(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _FaceAlignerOptions(base_options=base_options)
+    with _FaceAligner.create_from_options(options) as aligner:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
+          )
+      )
+      # Region-of-interest around the face.
+      roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
+      image_processing_options = _ImageProcessingOptions(roi)
+      # Performs face alignment on the input.
+      aligned_image = aligner.align(test_image, image_processing_options)
+      self.assertIsInstance(aligned_image, _Image)
+      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
+      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)
+
+  def test_align_succeeds_with_no_face_detected(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _FaceAlignerOptions(base_options=base_options)
+    with _FaceAligner.create_from_options(options) as aligner:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
+          )
+      )
+      # Region-of-interest that doesn't contain a human face.
+      roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
+      image_processing_options = _ImageProcessingOptions(roi)
+      # Performs face alignment on the input.
+      aligned_image = aligner.align(test_image, image_processing_options)
+      self.assertIsNone(aligned_image)
+
+  def test_missing_result_callback(self):
+    options = _FaceAlignerOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        running_mode=_RUNNING_MODE.LIVE_STREAM,
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback must be provided'
+    ):
+      with _FaceAligner.create_from_options(options) as unused_aligner:
+        pass
+
+  def test_illegal_result_callback(self):
+    options = _FaceAlignerOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        running_mode=_RUNNING_MODE.IMAGE,
+        result_callback=mock.MagicMock(),
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback should not be provided'
+    ):
+      with _FaceAligner.create_from_options(options) as unused_aligner:
+        pass
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/mediapipe/tasks/python/vision/face_aligner.py b/mediapipe/tasks/python/vision/face_aligner.py
new file mode 100644
index 000000000..615232f05
--- /dev/null
+++ b/mediapipe/tasks/python/vision/face_aligner.py
@@ -0,0 +1,201 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe face aligner task."""
+
+import dataclasses
+from typing import Callable, Mapping, Optional
+
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.python._framework_bindings import packet as packet_module
+from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
+from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
+
+_BaseOptions = base_options_module.BaseOptions
+_FaceStylizerGraphOptionsProto = (
+    face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
+)
+_RunningMode = running_mode_module.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
+_TaskInfo = task_info_module.TaskInfo
+
+_FACE_ALIGNMENT_IMAGE_NAME = 'stylized_image'
+_FACE_ALIGNMENT_IMAGE_TAG = 'FACE_ALIGNMENT'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
+_NORM_RECT_TAG = 'NORM_RECT'
+_IMAGE_IN_STREAM_NAME = 'image_in'
+_IMAGE_OUT_STREAM_NAME = 'image_out'
+_IMAGE_TAG = 'IMAGE'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+
+
+@dataclasses.dataclass
+class FaceAlignerOptions:
+  """Options for the face aligner task.
+
+  Attributes:
+    base_options: Base options for the face aligner task.
+    running_mode: The running mode of the task. Default to the image mode. Face
+      aligner task has three running modes: 1) The image mode for aligning one
+      face on a single image input. 2) The video mode for aligning one face per
+      frame on the decoded frames of a video. 3) The live stream mode for
+      aligning one face on a live stream of input data, such as from a camera.
+    result_callback: The user-defined result callback for processing live
+      stream data. The result callback should only be specified when the
+      running mode is set to the live stream mode.
+  """
+
+  base_options: _BaseOptions
+  running_mode: _RunningMode = _RunningMode.IMAGE
+  result_callback: Optional[
+      Callable[[image_module.Image, image_module.Image, int], None]
+  ] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
+    """Generates a FaceStylizerGraphOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    base_options_proto.use_stream_mode = (
+        False if self.running_mode == _RunningMode.IMAGE else True
+    )
+    return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
+
+
+class FaceAligner(base_vision_task_api.BaseVisionTaskApi):
+  """Class that performs face alignment on images."""
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'FaceAligner':
+    """Creates a `FaceAligner` object from a TensorFlow Lite model and the default `FaceAlignerOptions`.
+
+    Note that the created `FaceAligner` instance is in image mode, for
+    aligning one face on a single image input.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `FaceAligner` object that's created from the model file and the default
+      `FaceAlignerOptions`.
+
+    Raises:
+      ValueError: If failed to create `FaceAligner` object from the provided
+        file such as invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = FaceAlignerOptions(
+        base_options=base_options, running_mode=_RunningMode.IMAGE
+    )
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(cls, options: FaceAlignerOptions) -> 'FaceAligner':
+    """Creates the `FaceAligner` object from face aligner options.
+
+    Args:
+      options: Options for the face aligner task.
+
+    Returns:
+      `FaceAligner` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `FaceAligner` object from
+        `FaceAlignerOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+
+    def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
+      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+        return
+      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
+      aligned_image_packet = output_packets[_FACE_ALIGNMENT_IMAGE_NAME]
+      if aligned_image_packet.is_empty():
+        options.result_callback(
+            None,
+            image,
+            aligned_image_packet.timestamp.value
+            // _MICRO_SECONDS_PER_MILLISECOND,
+        )
+        return
+
+      aligned_image = packet_getter.get_image(aligned_image_packet)
+
+      options.result_callback(
+          aligned_image,
+          image,
+          aligned_image_packet.timestamp.value
+          // _MICRO_SECONDS_PER_MILLISECOND,
+      )
+
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[
+            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
+        ],
+        output_streams=[
+            ':'.join([_FACE_ALIGNMENT_IMAGE_TAG, _FACE_ALIGNMENT_IMAGE_NAME]),
+            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+        ],
+        task_options=options,
+    )
+    return cls(
+        task_info.generate_graph_config(
+            enable_flow_limiting=options.running_mode
+            == _RunningMode.LIVE_STREAM
+        ),
+        options.running_mode,
+        packets_callback if options.result_callback else None,
+    )
+
+  def align(
+      self,
+      image: image_module.Image,
+      image_processing_options: Optional[_ImageProcessingOptions] = None,
+  ) -> image_module.Image:
+    """Performs face alignment on the provided MediaPipe Image.
+
+    Only use this method when the FaceAligner is created with the image
+    running mode.
+
+    Args:
+      image: MediaPipe Image.
+      image_processing_options: Options for image processing.
+
+    Returns:
+      The aligned face image. The aligned output image size is the same as the
+      model output size. None if no face is detected on the input image.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid.
+      RuntimeError: If face alignment failed to run.
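+
+    Example (an illustrative sketch; the model path and `mp_image` are
+    placeholders):
+
+      with FaceAligner.create_from_model_path('face_stylizer.task') as aligner:
+        aligned_image = aligner.align(mp_image)
+        if aligned_image is None:
+          print('No face was detected in the input image.')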
+ """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + if output_packets[_FACE_ALIGNMENT_IMAGE_NAME].is_empty(): + return None + return packet_getter.get_image(output_packets[_FACE_ALIGNMENT_IMAGE_NAME]) From bd039f8b65d2b6bf623c909832c8b5ada0e81a6e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 1 May 2023 05:56:52 -0700 Subject: [PATCH 06/28] Updated necessary BUILD files --- mediapipe/tasks/python/test/vision/BUILD | 18 ++++++++++++++++++ mediapipe/tasks/python/vision/BUILD | 19 +++++++++++++++++++ mediapipe/tasks/testdata/vision/BUILD | 2 ++ 3 files changed, 39 insertions(+) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index e55e1b572..8488a3b1f 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -140,6 +140,24 @@ py_test( ], ) +py_test( + name = "face_aligner_test", + srcs = ["face_aligner_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_aligner", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) + py_test( name = "hand_landmarker_test", srcs = ["hand_landmarker_test.py"], diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index d0c97434f..dcd28dcf5 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -264,3 +264,22 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_aligner", + srcs = [ + "face_aligner.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 632e8aa4e..a8e06cad5 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -48,6 +48,7 @@ mediapipe_files(srcs = [ "face_landmark.tflite", "face_landmarker.task", "face_landmarker_v2.task", + "face_stylizer.task", "fist.jpg", "fist.png", "hair_segmentation.tflite", @@ -176,6 +177,7 @@ filegroup( "face_detection_short_range.tflite", "face_landmarker.task", "face_landmarker_v2.task", + "face_stylizer.task", "hair_segmentation.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", From 3719aaef7e8aab9a62985d5e1d947a8fe09b2540 Mon Sep 17 00:00:00 2001 From: Chuo-Ling Chang Date: Mon, 1 May 2023 23:50:27 -0700 Subject: [PATCH 07/28] Fix typo. 
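
The FaceLandmarkerOptions docstring described the task's running modes as
belonging to "HandLandmarker"; it should say "FaceLandmarker".
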
PiperOrigin-RevId: 528693117
---
 mediapipe/tasks/python/vision/face_landmarker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py
index 870e7e43e..44ddba87e 100644
--- a/mediapipe/tasks/python/vision/face_landmarker.py
+++ b/mediapipe/tasks/python/vision/face_landmarker.py
@@ -2939,7 +2939,7 @@ class FaceLandmarkerOptions:
   Attributes:
     base_options: Base options for the face landmarker task.
     running_mode: The running mode of the task. Default to the image mode.
-      HandLandmarker has three running modes: 1) The image mode for detecting
+      FaceLandmarker has three running modes: 1) The image mode for detecting
       face landmarks on single image inputs. 2) The video mode for detecting
       face landmarks on the decoded frames of a video. 3) The live stream mode
      for detecting face landmarks on the live stream of input data, such as

From 5b93477589412c3aa0e4a75d350380a797c1496e Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 2 May 2023 02:15:03 -0700
Subject: [PATCH 08/28] internal change

PiperOrigin-RevId: 528719459
---
 mediapipe/calculators/util/BUILD              |  2 +
 .../util/flat_color_image_calculator.cc       | 51 +++++++++---
 .../util/flat_color_image_calculator_test.cc  | 80 +++++++++++++++++++
 3 files changed, 123 insertions(+), 10 deletions(-)

diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD
index 22e6b0738..b6f50b840 100644
--- a/mediapipe/calculators/util/BUILD
+++ b/mediapipe/calculators/util/BUILD
@@ -1285,12 +1285,14 @@ cc_library(
     srcs = ["flat_color_image_calculator.cc"],
     deps = [
         ":flat_color_image_calculator_cc_proto",
+        "//mediapipe/framework:calculator_contract",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:node",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/formats:image_frame_opencv",
         "//mediapipe/framework/port:opencv_core",
+        "//mediapipe/framework/port:ret_check",
         "//mediapipe/util:color_cc_proto",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
diff --git a/mediapipe/calculators/util/flat_color_image_calculator.cc b/mediapipe/calculators/util/flat_color_image_calculator.cc
index 71d3582c5..f3b9c184c 100644
--- a/mediapipe/calculators/util/flat_color_image_calculator.cc
+++ b/mediapipe/calculators/util/flat_color_image_calculator.cc
@@ -15,14 +15,13 @@
 #include <memory>
 
 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
 #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/image_frame_opencv.h"
-#include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/util/color.pb.h"
 
 namespace mediapipe {
@@ -32,6 +31,7 @@ namespace {
 using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Node;
 using ::mediapipe::api2::Output;
+using ::mediapipe::api2::SideOutput;
 }  // namespace
 
 // A calculator for generating an image filled with a single color.
@@ -45,7 +45,8 @@ using ::mediapipe::api2::Output;
 //
 // Outputs:
 //   IMAGE (Image)
-//     Image filled with the requested color.
+//     Image filled with the requested color. Can be either an output_stream
+//     or an output_side_packet.
 //
 // Example usage:
 // node {
@@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node {
  public:
   static constexpr Input<Image>::Optional kInImage{"IMAGE"};
   static constexpr Input<Color>::Optional kInColor{"COLOR"};
-  static constexpr Output<Image> kOutImage{"IMAGE"};
+  static constexpr Output<Image>::Optional kOutImage{"IMAGE"};
+  static constexpr SideOutput<Image>::Optional kOutSideImage{"IMAGE"};
 
-  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
+  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage);
 
   static absl::Status UpdateContract(CalculatorContract* cc) {
     const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
@@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node {
     RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
         << "Either set COLOR input stream, or set through options";
 
+    RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected())
+        << "Set IMAGE either as output stream, or as output side packet";
+
+    RET_CHECK(!kOutSideImage(cc).IsConnected() ||
+              (options.has_output_height() && options.has_output_width()))
+        << "Set size through options, when setting IMAGE as output side packet";
+
     return absl::OkStatus();
   }
 
@@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node {
   absl::Status Open(CalculatorContext* cc) override;
   absl::Status Process(CalculatorContext* cc) override;
 
  private:
+  std::optional<std::shared_ptr<ImageFrame>> CreateOutputFrame(
+      CalculatorContext* cc);
+
   bool use_dimension_from_option_ = false;
   bool use_color_from_option_ = false;
 };
@@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
 
 absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
   use_dimension_from_option_ = !kInImage(cc).IsConnected();
   use_color_from_option_ = !kInColor(cc).IsConnected();
+
+  if (!kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutSideImage(cc).Set(Image(output_frame.value()));
+    }
+  }
   return absl::OkStatus();
 }
 
 absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
+  if (kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutImage(cc).Send(Image(output_frame.value()));
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+std::optional<std::shared_ptr<ImageFrame>>
+FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) {
   const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
 
   int output_height = -1;
@@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
     output_height = input_image.height();
     output_width = input_image.width();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
  }
 
   Color color;
@@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
   } else if (!kInColor(cc).IsEmpty()) {
     color = kInColor(cc).Get();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
   }
 
   auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
@@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
 
   output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
 
-  kOutImage(cc).Send(Image(output_frame));
-
-  return absl::OkStatus();
+  return output_frame;
 }
 
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/util/flat_color_image_calculator_test.cc b/mediapipe/calculators/util/flat_color_image_calculator_test.cc
index 53c6de1b1..c09064bf2 100644
--- a/mediapipe/calculators/util/flat_color_image_calculator_test.cc
+++ b/mediapipe/calculators/util/flat_color_image_calculator_test.cc
@@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
   }
 }
 
+TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get<Image>();
+  EXPECT_EQ(image.width(), 1);
+  EXPECT_EQ(image.height(), 1);
+  auto image_frame = image.GetImageFrameSharedPtr();
+  const uint8_t* pixel_data = image_frame->PixelData();
+  EXPECT_EQ(pixel_data[0], 100);
+  EXPECT_EQ(pixel_data[1], 200);
+  EXPECT_EQ(pixel_data[2], 255);
+}
+
 TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
   CalculatorRunner runner(R"pb(
     calculator: "FlatColorImageCalculator"
@@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
                 HasSubstr("Either set COLOR input stream"));
 }
 
+TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_stream: "IMAGE:out_image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  ASSERT_THAT(
+      runner.Run().message(),
+      HasSubstr("Set IMAGE either as output stream, or as output side packet"));
+}
+
+TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    input_stream: "IMAGE:image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
+                                                  kImageWidth, kImageHeight);
+
+  for (int ts = 0; ts < 3; ++ts) {
+    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
+        MakePacket<Image>(image_frame).At(Timestamp(ts)));
+  }
+  ASSERT_THAT(runner.Run().message(),
+              HasSubstr("Set size through options, when setting IMAGE as "
+                        "output side packet"));
+}
+
 }  // namespace
 }  // namespace mediapipe

From 7fdbbee5be85c06758c01211ea0c81bbcf223ccc Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 2 May 2023 09:04:56 -0700
Subject: [PATCH 09/28] Internal change

PiperOrigin-RevId: 528799585
---
 mediapipe/framework/tool/BUILD        | 1 +
 mediapipe/framework/tool/test_util.cc | 8 ++++++++
 2 files changed, 9 insertions(+)

diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD
index fbdcf8c9e..4ae0bb607 100644
--- a/mediapipe/framework/tool/BUILD
+++ b/mediapipe/framework/tool/BUILD
@@ -791,6 +791,7 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@stblib//:stb_image",
         "@stblib//:stb_image_write",
     ],
diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc
index 5642941e9..64b5072c5 100644
--- a/mediapipe/framework/tool/test_util.cc
+++ b/mediapipe/framework/tool/test_util.cc
@@ -26,6 +26,7 @@
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
 #include "mediapipe/framework/calculator.pb.h"
@@ -311,6 +312,13 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
 // Returns the path to the output if successful.
 absl::StatusOr<std::string> SavePngTestOutput(
     const mediapipe::ImageFrame& image, absl::string_view prefix) {
+  absl::flat_hash_set<ImageFormat::Format> supported_formats = {
+      ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA,
+      ImageFormat::LAB8, ImageFormat::SBGRA};
+  if (!supported_formats.contains(image.Format())) {
+    return absl::CancelledError(
+        absl::StrFormat("Format %d can not be saved to PNG.", image.Format()));
+  }
   std::string now_string = absl::FormatTime(absl::Now());
   std::string output_relative_path =
       absl::StrCat(prefix, "_", now_string, ".png");

From 60055f6feecdf9e7117622aabb68aedc77943629 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Tue, 2 May 2023 10:28:28 -0700
Subject: [PATCH 10/28] Add more comments and usage example of the face
 stylizer graph.

PiperOrigin-RevId: 528823127
---
 .../tasks/cc/vision/face_stylizer/face_stylizer_graph.cc | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc
index 7c4e6f138..cb49ef59d 100644
--- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc
+++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc
@@ -199,7 +199,9 @@ void ConfigureTensorsToImageCalculator(
 // STYLIZED_IMAGE - mediapipe::Image
 //   The face stylization output image.
 // FACE_ALIGNMENT - mediapipe::Image
-//   The face alignment output image.
+//   The aligned face image that is fed to the face stylization model to
+//   perform stylization. Also useful for preparing face stylization training
+//   data.
 // IMAGE - mediapipe::Image
 //   The input image that the face landmarker runs on and has the pixel data
 //   stored on the target storage (CPU vs GPU).
@@ -211,6 +213,7 @@ void ConfigureTensorsToImageCalculator(
 //   input_stream: "NORM_RECT:norm_rect"
 //   output_stream: "IMAGE:image_out"
 //   output_stream: "STYLIZED_IMAGE:stylized_image"
+//   output_stream: "FACE_ALIGNMENT:face_alignment_image"
 //   options {
 //     [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext]
 //     {
@@ -248,7 +251,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
             ->mutable_face_landmarker_graph_options(),
         graph[Input<Image>(kImageTag)],
         graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
-    const ModelResources* face_stylizer_model_resources;
+    const ModelResources* face_stylizer_model_resources = nullptr;
     if (output_stylized) {
       ASSIGN_OR_RETURN(
           const auto* model_resources,
@@ -332,7 +335,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     auto face_rect = face_to_rect.Out(kNormRectTag);
 
     std::optional<Source<Image>> face_alignment;
-    // Output face alignment only.
+    // Output aligned face only.
     // In this case, the face stylization model inference is not required.
     // However, to keep consistent with the inference preprocessing steps, the
     // ImageToTensorCalculator is still used to perform image rotation,

From bd039f8b65d2b6bf623c909832c8b5ada0e81a6e Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 2 May 2023 10:35:16 -0700
Subject: [PATCH 11/28] Fix msan errors.
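
Use a stack-allocated ImageFrame and initialize its pixel data before
calling PaddingEffectGenerator::Process, so MSan no longer reports reads
of uninitialized memory in padding_effect_generator_test.
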
PiperOrigin-RevId: 528825081
---
 .../autoflip/quality/padding_effect_generator_test.cc | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc
index 84b229d80..4c9e96b88 100644
--- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc
+++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc
@@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) {
   double target_aspect_ratio = 0.5;
   int expect_width = 14;
   int expect_height = input_height;
-  auto test_frame = absl::make_unique<ImageFrame>(/*format=*/ImageFormat::SRGB,
-                                                  input_width, input_height);
+  ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width,
+                        input_height);
+  cv::Mat mat = formats::MatView(&test_frame);
+  mat = cv::Scalar(0, 0, 0);
 
-  PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(),
+  PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(),
                                    target_aspect_ratio,
                                    /*scale_to_multiple_of_two=*/true);
   ImageFrame result_frame;
-  MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame));
+  MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame));
   EXPECT_EQ(result_frame.Width(), expect_width);
   EXPECT_EQ(result_frame.Height(), expect_height);
 }

From 421c9e8e97d8b33332b7564504c4f5090128474d Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 2 May 2023 10:49:40 -0700
Subject: [PATCH 12/28] Fix typo

PiperOrigin-RevId: 528829423
---
 mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
index 8d83ac2c8..2e5f7e416 100644
--- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
@@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
     auto matrix = preprocessing.Out(kMatrixTag);
     auto image_size = preprocessing.Out(kImageSizeTag);
 
-    // Face detection model inferece.
+    // Face detection model inference.
     auto& inference = AddInference(
         model_resources, subgraph_options.base_options().acceleration(), graph);
     preprocessed_tensors >> inference.In(kTensorsTag);

From 9ce16fddebcf0c7152ef9ecee670a6d352f4cfd6 Mon Sep 17 00:00:00 2001
From: Yuqi Li
Date: Tue, 2 May 2023 11:54:23 -0700
Subject: [PATCH 13/28] nit: format the documentation of
 LandmarksDetectionResult.

PiperOrigin-RevId: 528848566
---
 .../components/containers/landmark_detection_result.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py
index c60ad850c..fdb719b92 100644
--- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py
+++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py
@@ -39,9 +39,11 @@ _Landmark = landmark_module.Landmark
 class LandmarksDetectionResult:
   """Represents the landmarks detection result.
 
-  Attributes: landmarks : A list of `NormalizedLandmark` objects. categories : A
-    list of `Category` objects. world_landmarks : A list of `Landmark` objects.
-    rect : A `NormalizedRect` object.
+ Attributes: + landmarks: A list of `NormalizedLandmark` objects. + categories: A list of `Category` objects. + world_landmarks: A list of `Landmark` objects. + rect: A `NormalizedRect` object. """ landmarks: Optional[List[_NormalizedLandmark]] From 4d9812af4396bc8fe72fe7618919bdd189a84afb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 2 May 2023 13:49:24 -0700 Subject: [PATCH 14/28] Pose detector uses `advanced_gpu_api` for gpu inference to resolve unsupported gpu op issue. PiperOrigin-RevId: 528879218 --- .../pose_landmarker/pose_landmarker_graph.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 826de5ec4..7889212e8 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_model_asset(), is_copy); } - pose_detector_graph_options->mutable_base_options() - ->mutable_acceleration() - ->CopyFrom(options->base_options().acceleration()); + if (options->base_options().acceleration().has_gpu()) { + core::proto::Acceleration gpu_accel; + gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true); + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(gpu_accel); + + } else { + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + } pose_detector_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); auto* pose_landmarks_detector_graph_options = From bf11fb313e34766a6f4e856f502178ebdaf0808d Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 2 May 2023 14:01:21 -0700 Subject: [PATCH 15/28] Expose PoseLandmarker as a public MediaPipe Tasks Python API. PiperOrigin-RevId: 528882303 --- mediapipe/tasks/python/vision/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index 75a8bd323..9dee86401 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -25,6 +25,7 @@ import mediapipe.tasks.python.vision.image_embedder import mediapipe.tasks.python.vision.image_segmenter import mediapipe.tasks.python.vision.interactive_segmenter import mediapipe.tasks.python.vision.object_detector +import mediapipe.tasks.python.vision.pose_landmarker FaceDetector = face_detector.FaceDetector FaceDetectorOptions = face_detector.FaceDetectorOptions @@ -54,6 +55,10 @@ InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions InteractiveSegmenterRegionOfInterest = interactive_segmenter.RegionOfInterest ObjectDetector = object_detector.ObjectDetector ObjectDetectorOptions = object_detector.ObjectDetectorOptions +ObjectDetectorResult = object_detector.ObjectDetectorResult +PoseLandmarker = pose_landmarker.PoseLandmarker +PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions +PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult RunningMode = core.vision_task_running_mode.VisionTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. 
@@ -68,4 +73,5 @@ del image_embedder del image_segmenter del interactive_segmenter del object_detector +del pose_landmarker del mediapipe From c698381e48ad70b30fb752b71c2e702ef2ceed8f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 2 May 2023 18:01:12 -0700 Subject: [PATCH 16/28] Internal change PiperOrigin-RevId: 528939095 --- .../com/google/mediapipe/tasks/vision/BUILD | 1 + .../PoseLandmarksConnections.java | 80 +++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 5be0e233f..c27da79c7 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -180,6 +180,7 @@ android_library( srcs = [ "poselandmarker/PoseLandmarker.java", "poselandmarker/PoseLandmarkerResult.java", + "poselandmarker/PoseLandmarksConnections.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java new file mode 100644 index 000000000..9be6a9aeb --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java @@ -0,0 +1,80 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.poselandmarker; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** Pose landmarks connection constants. */ +public final class PoseLandmarksConnections { + + /** Value class representing pose landmarks connection. 
+   */
+  @AutoValue
+  public abstract static class Connection {
+    static Connection create(int start, int end) {
+      return new AutoValue_PoseLandmarksConnections_Connection(start, end);
+    }
+
+    public abstract int start();
+
+    public abstract int end();
+  }
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> POSE_LANDMARKS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(0, 1),
+                  Connection.create(1, 2),
+                  Connection.create(2, 3),
+                  Connection.create(3, 7),
+                  Connection.create(0, 4),
+                  Connection.create(4, 5),
+                  Connection.create(5, 6),
+                  Connection.create(6, 8),
+                  Connection.create(9, 10),
+                  Connection.create(11, 12),
+                  Connection.create(11, 13),
+                  Connection.create(13, 15),
+                  Connection.create(15, 17),
+                  Connection.create(15, 19),
+                  Connection.create(15, 21),
+                  Connection.create(17, 19),
+                  Connection.create(12, 14),
+                  Connection.create(14, 16),
+                  Connection.create(16, 18),
+                  Connection.create(16, 20),
+                  Connection.create(16, 22),
+                  Connection.create(18, 20),
+                  Connection.create(11, 23),
+                  Connection.create(12, 24),
+                  Connection.create(23, 24),
+                  Connection.create(23, 25),
+                  Connection.create(24, 26),
+                  Connection.create(25, 27),
+                  Connection.create(26, 28),
+                  Connection.create(27, 29),
+                  Connection.create(28, 30),
+                  Connection.create(29, 31),
+                  Connection.create(30, 32),
+                  Connection.create(27, 31),
+                  Connection.create(28, 32))));
+
+  private PoseLandmarksConnections() {}
+}

From 1dea01aecc3ae30acf7698ac9eb5acf120712eb2 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 2 May 2023 22:55:12 -0700
Subject: [PATCH 17/28] Internal change

PiperOrigin-RevId: 528996603
---
 .../com/google/mediapipe/tasks/vision/BUILD   |   1 +
 .../HandLandmarksConnections.java             | 105 ++++++++++++++++++
 2 files changed, 106 insertions(+)
 create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index c27da79c7..a2dbe351a 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -213,6 +213,7 @@ android_library(
     srcs = [
         "handlandmarker/HandLandmark.java",
         "handlandmarker/HandLandmarker.java",
         "handlandmarker/HandLandmarkerResult.java",
+        "handlandmarker/HandLandmarksConnections.java",
    ],
     javacopts = [
         "-Xep:AndroidJdkLibsChecker:OFF",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java
new file mode 100644
index 000000000..c60923840
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java
@@ -0,0 +1,105 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.handlandmarker;
+
+import com.google.auto.value.AutoValue;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/** Hand landmarks connection constants. */
+public final class HandLandmarksConnections {
+
+  /** Value class representing hand landmarks connection. */
+  @AutoValue
+  public abstract static class Connection {
+    static Connection create(int start, int end) {
+      return new AutoValue_HandLandmarksConnections_Connection(start, end);
+    }
+
+    public abstract int start();
+
+    public abstract int end();
+  }
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_PALM_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(0, 1),
+                  Connection.create(0, 5),
+                  Connection.create(9, 13),
+                  Connection.create(13, 17),
+                  Connection.create(5, 9),
+                  Connection.create(0, 17))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_THUMB_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(1, 2), Connection.create(2, 3), Connection.create(3, 4))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_INDEX_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(5, 6), Connection.create(6, 7), Connection.create(7, 8))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_MIDDLE_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(9, 10), Connection.create(10, 11), Connection.create(11, 12))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_RING_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(13, 14),
+                  Connection.create(14, 15),
+                  Connection.create(15, 16))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_PINKY_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(17, 18),
+                  Connection.create(18, 19),
+                  Connection.create(19, 20))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_CONNECTIONS =
+      Collections.unmodifiableSet(
+          Stream.of(
+                  HAND_PALM_CONNECTIONS.stream(),
+                  HAND_THUMB_CONNECTIONS.stream(),
+                  HAND_INDEX_FINGER_CONNECTIONS.stream(),
+                  HAND_MIDDLE_FINGER_CONNECTIONS.stream(),
+                  HAND_RING_FINGER_CONNECTIONS.stream(),
+                  HAND_PINKY_FINGER_CONNECTIONS.stream())
+              .flatMap(i -> i)
+              .collect(Collectors.toSet()));
+
+  private HandLandmarksConnections() {}
+}

From 3789156a41cc4952a4a89f333d092d23b3eaa18d Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 00:26:40 -0700
Subject: [PATCH 18/28] Internal change

PiperOrigin-RevId: 529011480
---
 .../port/drishti_proto_alias_rules.bzl        | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 mediapipe/framework/port/drishti_proto_alias_rules.bzl

diff --git a/mediapipe/framework/port/drishti_proto_alias_rules.bzl b/mediapipe/framework/port/drishti_proto_alias_rules.bzl
new file mode 100644
index 000000000..7df141cbe
--- /dev/null
+++ b/mediapipe/framework/port/drishti_proto_alias_rules.bzl
@@ -0,0 +1,31 @@
+"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly."""
+
+def _copy_header_impl(ctx):
+    source = ctx.attr.source.replace("//", "").replace(":", "/")
+    files = []
+    for dep in ctx.attr.deps:
+        for header in dep[CcInfo].compilation_context.direct_headers:
+            if (header.short_path == source):
+                files.append(header)
+    if len(files) != 1:
+        fail("Expected exactly 1 source, got ", str(files))
+    dest_file = ctx.actions.declare_file(ctx.attr.filename)
+
+    # Use expand_template() with no substitutions as a simple copier.
+    ctx.actions.expand_template(
+        template = files[0],
+        output = dest_file,
+        substitutions = {},
+    )
+    return [DefaultInfo(files = depset([dest_file]))]
+
+copy_header = rule(
+    implementation = _copy_header_impl,
+    attrs = {
+        "filename": attr.string(),
+        "source": attr.string(),
+        "deps": attr.label_list(providers = [CcInfo]),
+    },
+    output_to_genfiles = True,
+    outputs = {"out": "%{filename}"},
+)
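The rule above is a plain header copier: it resolves `source` against the direct headers of `deps` and re-emits the match under `filename` via `expand_template`. A minimal BUILD sketch of how it could be loaded and invoked (the package and target names here are hypothetical, chosen only to illustrate the label-to-short_path matching):

load("//mediapipe/framework/port:drishti_proto_alias_rules.bzl", "copy_header")

# "//some/pkg:foo.pb.h" is rewritten to "some/pkg/foo.pb.h" by the rule
# implementation and compared against each dependency header's short_path.
copy_header(
    name = "foo_alias_header",
    filename = "foo_alias.pb.h",
    source = "//some/pkg:foo.pb.h",
    deps = ["//some/pkg:foo_cc_proto"],
)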
From baa8fc68a1b3b3280968fa526413b486ffd5229b Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 05:55:25 -0700
Subject: [PATCH 19/28] Make uploading to GPU optional in Image.GetGpuBuffer().

PiperOrigin-RevId: 529066617
---
 .../tensor/image_to_tensor_converter_frame_buffer.cc |  3 ++-
 mediapipe/framework/formats/image.h                  | 10 +++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc
index 093f50d76..6f6f6f11c 100644
--- a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc
+++ b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc
@@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input,
               static_cast<int>(range_max) == 255);
   }
 
-  auto input_frame = input.GetGpuBuffer().GetReadView<FrameBuffer>();
+  auto input_frame =
+      input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>();
   const auto& output_shape = output_tensor.shape();
   MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
   FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2],
diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h
index ffb6362f3..936a3554e 100644
--- a/mediapipe/framework/formats/image.h
+++ b/mediapipe/framework/formats/image.h
@@ -113,11 +113,11 @@ class Image {
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
-  // Get a GPU view. Automatically uploads from CPU if needed.
-  const mediapipe::GpuBuffer GetGpuBuffer() const {
-#if !MEDIAPIPE_DISABLE_GPU
-    if (use_gpu_ == false) ConvertToGpu();
-#endif  // !MEDIAPIPE_DISABLE_GPU
+  // Provides access to the underlying GpuBuffer storage.
+  // Automatically uploads from CPU to GPU if needed and requested through the
+  // `upload_to_gpu` argument.
+  const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const {
+    if (!use_gpu_ && upload_to_gpu) ConvertToGpu();
     return gpu_buffer_;
   }

From 09662749ea433f1d5b8b4b9a9b86c341a1574658 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 11:58:26 -0700
Subject: [PATCH 20/28] Support scribble input for Interactive Segmenter

PiperOrigin-RevId: 529156049
---
 .../interactive_segmenter.cc                  | 18 +++++-
 .../interactive_segmenter.h                   |  8 ++-
 .../interactive_segmenter_test.cc             | 63 +++++++++++++++----
 mediapipe/util/annotation_renderer.cc         | 15 ++++-
 mediapipe/util/annotation_renderer.h          |  5 ++
 mediapipe/util/render_data.proto              |  5 ++
 6 files changed, 99 insertions(+), 15 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc
index af2a3f50c..c0d89c87d 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
 
+using components::containers::NormalizedKeypoint;
+
 using ::mediapipe::CalculatorGraphConfig;
 using ::mediapipe::Image;
 using ::mediapipe::NormalizedRect;
@@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
     case RegionOfInterest::Format::kUnspecified:
       return absl::InvalidArgumentError(
           "RegionOfInterest format not specified");
-    case RegionOfInterest::Format::kKeyPoint:
+    case RegionOfInterest::Format::kKeyPoint: {
       RET_CHECK(roi.keypoint.has_value());
       auto* annotation = result.add_render_annotations();
       annotation->mutable_color()->set_r(255);
       auto* point = annotation->mutable_point();
       point->set_x(roi.keypoint->x);
       point->set_y(roi.keypoint->y);
       return result;
+    }
+    case RegionOfInterest::Format::kScribble: {
+      RET_CHECK(roi.scribble.has_value());
+      auto* annotation = result.add_render_annotations();
+      annotation->mutable_color()->set_r(255);
+      for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
+        auto* point = annotation->mutable_scribble()->add_point();
+        point->set_normalized(true);
+        point->set_x(keypoint.x);
+        point->set_y(keypoint.y);
+      }
+      return result;
+    }
   }
   return absl::UnimplementedError("Unrecognized format");
 }
diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h
index ad4a238df..ad8a558df 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h
@@ -53,6 +53,7 @@ struct RegionOfInterest {
   enum class Format {
     kUnspecified = 0,  // Format not specified.
     kKeyPoint = 1,     // Using keypoint to represent ROI.
+    kScribble = 2,     // Using scribble to represent ROI.
   };
 
   // Specifies the format used to specify the region-of-interest. Note that
@@ -61,8 +62,13 @@ struct RegionOfInterest {
   Format format = Format::kUnspecified;
 
   // Represents the ROI in keypoint format, this should be non-nullopt if
-  // `format` is `KEYPOINT`.
+  // `format` is `kKeyPoint`.
   std::optional<components::containers::NormalizedKeypoint> keypoint;
+
+  // Represents the ROI in scribble format, this should be non-nullopt if
+  // `format` is `kScribble`.
+  std::optional<std::vector<components::containers::NormalizedKeypoint>>
+      scribble;
 };
 
 // Performs interactive segmentation on images.
diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
index 443247aea..16d065f61 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
@@ -18,9 +18,12 @@ limitations under the License.
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <variant>
+#include <vector>
 
 #include "absl/flags/flag.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
@@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
 
 struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
-  NormalizedKeypoint roi;
+  std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };
 
-using SucceedSegmentationWithRoi =
-    ::testing::TestWithParam<InteractiveSegmenterTestParams>;
+class SucceedSegmentationWithRoi
+    : public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
+ public:
+  absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
+    const InteractiveSegmenterTestParams& params = GetParam();
+
+    RegionOfInterest interaction_roi;
+    interaction_roi.format = params.format;
+    switch (params.format) {
+      case (RegionOfInterest::Format::kKeyPoint): {
+        interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);
+        break;
+      }
+      case (RegionOfInterest::Format::kScribble): {
+        interaction_roi.scribble =
+            std::get<std::vector<NormalizedKeypoint>>(params.roi);
+        break;
+      }
+      default: {
+        return absl::InvalidArgumentError("Unknown ROI format");
+      }
+    }
+
+    return interaction_roi;
+  }
+};
 
 TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
   const InteractiveSegmenterTestParams& params = GetParam();
+
   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
 }
 
 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
-  const auto& params = GetParam();
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
+  const InteractiveSegmenterTestParams& params = GetParam();
+
   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
 
 INSTANTIATE_TEST_SUITE_P(
     SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
-        {{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
+        {// Keypoint input.
+         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
          {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
-          kGoldenMaskSimilarity}}),
+          kGoldenMaskSimilarity},
+         // Scribble input.
+         {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
+          std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.44, 0.70},
+                                          NormalizedKeypoint{0.44, 0.71},
+                                          NormalizedKeypoint{0.44, 0.72}},
+          kCatsAndDogsMaskDog1, 0.84f},
+         {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
+          std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.66, 0.66},
+                                          NormalizedKeypoint{0.66, 0.67},
+                                          NormalizedKeypoint{0.66, 0.68}},
+          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
           info) { return info.param.test_name; });
diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc
index 5188da896..d8516f9bc 100644
--- a/mediapipe/util/annotation_renderer.cc
+++ b/mediapipe/util/annotation_renderer.cc
@@ -22,6 +22,7 @@
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/vector.h"
 #include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
 
 namespace mediapipe {
 namespace {
@@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
       DrawGradientLine(annotation);
     } else if (annotation.data_case() == RenderAnnotation::kArrow) {
       DrawArrow(annotation);
+    } else if (annotation.data_case() == RenderAnnotation::kScribble) {
+      DrawScribble(annotation);
     } else {
       LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
     }
@@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
 }
 
 void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
-  const auto& point = annotation.point();
+  DrawPoint(annotation.point(), annotation);
+}
+
+void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
+                                   const RenderAnnotation& annotation) {
   int x = -1;
   int y = -1;
   if (point.normalized()) {
@@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
   cv::circle(mat_image_, point_to_draw, thickness, color, -1);
 }
 
+void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
+  for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
+    DrawPoint(point, annotation);
+  }
+}
+
 void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
   int x_start = -1;
   int y_start = -1;
diff --git a/mediapipe/util/annotation_renderer.h b/mediapipe/util/annotation_renderer.h
index 380bc3614..ae0cf976e 100644
--- a/mediapipe/util/annotation_renderer.h
+++ b/mediapipe/util/annotation_renderer.h
@@ -96,6 +96,11 @@ class AnnotationRenderer {
 
   // Draws a point on the image as described in the annotation.
   void DrawPoint(const RenderAnnotation& annotation);
+  void DrawPoint(const RenderAnnotation::Point& point,
+                 const RenderAnnotation& annotation);
+
+  // Draws scribbles on the image as described in the annotation.
+  void DrawScribble(const RenderAnnotation& annotation);
 
   // Draws a line segment on the image as described in the annotation.
void DrawLine(const RenderAnnotation& annotation); diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index fee02fff3..897d5fa37 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -131,6 +131,10 @@ message RenderAnnotation { optional Color color2 = 7; } + message Scribble { + repeated Point point = 1; + } + message Arrow { // The arrow head will be drawn at (x_end, y_end). optional double x_start = 1; @@ -192,6 +196,7 @@ message RenderAnnotation { RoundedRectangle rounded_rectangle = 9; FilledRoundedRectangle filled_rounded_rectangle = 10; GradientLine gradient_line = 14; + Scribble scribble = 15; } // Thickness for drawing the annotation. From c78055921492d567d89e36b3d75d069f6e376927 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 3 May 2023 12:18:16 -0700 Subject: [PATCH 21/28] Internal change PiperOrigin-RevId: 529161249 --- mediapipe/tasks/python/vision/__init__.py | 1 + .../tasks/python/vision/pose_landmarker.py | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index 9dee86401..cea950ea7 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -59,6 +59,7 @@ ObjectDetectorResult = object_detector.ObjectDetectorResult PoseLandmarker = pose_landmarker.PoseLandmarker PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult +PoseLandmarksConnections = pose_landmarker.PoseLandmarksConnections RunningMode = core.vision_task_running_mode.VisionTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. diff --git a/mediapipe/tasks/python/vision/pose_landmarker.py b/mediapipe/tasks/python/vision/pose_landmarker.py index b91eb0326..3ff7edb0a 100644 --- a/mediapipe/tasks/python/vision/pose_landmarker.py +++ b/mediapipe/tasks/python/vision/pose_landmarker.py @@ -132,6 +132,55 @@ def _build_landmarker_result( return pose_landmarker_result +class PoseLandmarksConnections: + """The connections between pose landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for pose landmarks.""" + + start: int + end: int + + POSE_LANDMARKS: List[Connection] = [ + Connection(0, 1), + Connection(1, 2), + Connection(2, 3), + Connection(3, 7), + Connection(0, 4), + Connection(4, 5), + Connection(5, 6), + Connection(6, 8), + Connection(9, 10), + Connection(11, 12), + Connection(11, 13), + Connection(13, 15), + Connection(15, 17), + Connection(15, 19), + Connection(15, 21), + Connection(17, 19), + Connection(12, 14), + Connection(14, 16), + Connection(16, 18), + Connection(16, 20), + Connection(16, 22), + Connection(18, 20), + Connection(11, 23), + Connection(12, 24), + Connection(23, 24), + Connection(23, 25), + Connection(24, 26), + Connection(25, 27), + Connection(26, 28), + Connection(27, 29), + Connection(28, 30), + Connection(29, 31), + Connection(30, 32), + Connection(27, 31), + Connection(28, 32) + ] + + @dataclasses.dataclass class PoseLandmarkerOptions: """Options for the pose landmarker task. 
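Since patch 15 already re-exports the pose landmarker symbols from `mediapipe.tasks.python.vision`, the new connection list can be consumed directly for visualization. A minimal sketch, assuming `result` is a `PoseLandmarkerResult` and `frame` is a BGR numpy array (the `draw_pose` helper and the OpenCV calls are illustrative assumptions, not part of the patch):

import cv2

from mediapipe.tasks.python import vision


def draw_pose(frame, result):
  # Landmarks are normalized to [0, 1]; scale them to pixel coordinates.
  h, w = frame.shape[:2]
  for pose in result.pose_landmarks:  # one entry per detected pose
    for conn in vision.PoseLandmarksConnections.POSE_LANDMARKS:
      start, end = pose[conn.start], pose[conn.end]
      cv2.line(frame, (int(start.x * w), int(start.y * h)),
               (int(end.x * w), int(end.y * h)), (0, 255, 0), 2)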
From 7c955246aaf8f9a4d0200317c0500591d877f8a7 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 13:19:51 -0700
Subject: [PATCH 22/28] Support scribble input for Interactive Segmenter Java
 API

PiperOrigin-RevId: 529177660
---
 .../InteractiveSegmenter.java                 | 23 ++++++++
 .../InteractiveSegmenterTest.java             | 57 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
index 52a5f2a67..e9ff1f2b5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
@@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
   /** The Region-Of-Interest (ROI) to interact with. */
   public static class RegionOfInterest {
     private NormalizedKeypoint keypoint;
+    private List<NormalizedKeypoint> scribble;
 
     private RegionOfInterest() {}
 
@@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
       roi.keypoint = keypoint;
       return roi;
     }
+
+    /**
+     * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
+     * user wants to segment.
+     */
+    public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
+      RegionOfInterest roi = new RegionOfInterest();
+      roi.scribble = scribble;
+      return roi;
+    }
   }
 
   /**
@@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
                       .setX(roi.keypoint.x())
                       .setY(roi.keypoint.y())))
           .build();
+    } else if (roi.scribble != null) {
+      RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
+      for (NormalizedKeypoint p : roi.scribble) {
+        scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
+      }
+
+      return builder
+          .addRenderAnnotations(
+              RenderAnnotation.newBuilder()
+                  .setColor(Color.newBuilder().setR(255))
+                  .setScribble(scribbleBuilder))
+          .build();
     }
 
     throw new IllegalArgumentException(
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java
index 506036ba2..a534970f7 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java
@@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
 import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
 import java.io.InputStream;
+import java.util.ArrayList;
 import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
 
 /** Test for {@link InteractiveSegmenter}. */
 @RunWith(Suite.class)
 @SuiteClasses({
-  InteractiveSegmenterTest.General.class,
+  InteractiveSegmenterTest.KeypointRoi.class,
+  InteractiveSegmenterTest.ScribbleRoi.class,
 })
 public class InteractiveSegmenterTest {
   private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
@@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
   private static final int MAGNIFICATION_FACTOR = 10;
 
   @RunWith(AndroidJUnit4.class)
-  public static final class General extends InteractiveSegmenterTest {
+  public static final class KeypointRoi extends InteractiveSegmenterTest {
     @Test
     public void segment_successWithCategoryMask() throws Exception {
       final String inputImageName = CATS_AND_DOGS_IMAGE;
@@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
     }
   }
 
+  @RunWith(AndroidJUnit4.class)
+  public static final class ScribbleRoi extends InteractiveSegmenterTest {
+    @Test
+    public void segment_successWithCategoryMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(false)
+              .setOutputCategoryMask(true)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      MPImage image = getImageFromAsset(inputImageName);
+      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
+      assertThat(actualResult.categoryMask().isPresent()).isTrue();
+    }
+
+    @Test
+    public void segment_successWithConfidenceMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(true)
+              .setOutputCategoryMask(false)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      ImageSegmenterResult actualResult =
+          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
+      assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
+      List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
+      assertThat(confidenceMasks.size()).isEqualTo(2);
+    }
+  }
+
   private static MPImage getImageFromAsset(String filePath) throws Exception {
     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
     InputStream istr = assetManager.open(filePath);

From 606b83ac65911c384537044831f15ad7a90d9573 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 13:30:08 -0700
Subject: [PATCH 23/28] Internal change

PiperOrigin-RevId: 529180655
---
 mediapipe/tasks/python/vision/__init__.py    |  1 +
 .../tasks/python/vision/hand_landmarker.py   | 59 +++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py
index cea950ea7..3c3f34db0 100644
--- a/mediapipe/tasks/python/vision/__init__.py
+++ b/mediapipe/tasks/python/vision/__init__.py
@@ -42,6 +42,7 @@ GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
 HandLandmarker = hand_landmarker.HandLandmarker
 HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
 HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
+HandLandmarksConnections = hand_landmarker.HandLandmarksConnections
 ImageClassifier = image_classifier.ImageClassifier
 ImageClassifierOptions = image_classifier.ImageClassifierOptions
 ImageClassifierResult = image_classifier.ImageClassifierResult
diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py
index 1f2c629d2..e781c8882 100644
--- a/mediapipe/tasks/python/vision/hand_landmarker.py
+++ b/mediapipe/tasks/python/vision/hand_landmarker.py
@@ -82,6 +82,65 @@ class HandLandmark(enum.IntEnum):
   PINKY_TIP = 20
 
 
+class HandLandmarksConnections:
+  """The connections between hand landmarks."""
+
+  @dataclasses.dataclass
+  class Connection:
+    """The connection class for hand landmarks."""
+
+    start: int
+    end: int
+
+  HAND_PALM_CONNECTIONS: List[Connection] = [
+      Connection(0, 1),
+      Connection(0, 5),
+      Connection(9, 13),
+      Connection(13, 17),
+      Connection(5, 9),
+      Connection(0, 17),
+  ]
+
+  HAND_THUMB_CONNECTIONS: List[Connection] = [
+      Connection(1, 2),
+      Connection(2, 3),
+      Connection(3, 4),
+  ]
+
+  HAND_INDEX_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(5, 6),
+      Connection(6, 7),
+      Connection(7, 8),
+  ]
+
+  HAND_MIDDLE_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(9, 10),
+      Connection(10, 11),
+      Connection(11, 12),
+  ]
+
+  HAND_RING_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(13, 14),
+      Connection(14, 15),
+      Connection(15, 16),
+  ]
+
+  HAND_PINKY_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(17, 18),
+      Connection(18, 19),
+      Connection(19, 20),
+  ]
+
+  HAND_CONNECTIONS: List[Connection] = (
+      HAND_PALM_CONNECTIONS
+      + HAND_THUMB_CONNECTIONS
+      + HAND_INDEX_FINGER_CONNECTIONS
+      + HAND_MIDDLE_FINGER_CONNECTIONS
+      + HAND_RING_FINGER_CONNECTIONS
+      + HAND_PINKY_FINGER_CONNECTIONS
+  )
+
+
 @dataclasses.dataclass
 class HandLandmarkerResult:
   """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.
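These Python lists are meant to stay in lockstep with the Java constants from patch 17 (note the palm list above is normalized to the Java pairs), so a cheap consistency check is easy to keep in a test. An illustrative snippet (not part of the patch), relying only on the fact that valid `HandLandmark` indices run from 0 to 20:

from mediapipe.tasks.python import vision

# HAND_CONNECTIONS aggregates the palm, thumb, and four finger lists; every
# endpoint must be a valid hand landmark index.
for conn in vision.HandLandmarksConnections.HAND_CONNECTIONS:
  assert 0 <= conn.start <= 20 and 0 <= conn.end <= 20, conn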
From e428bdb7e89aa101e442d05cc75157915f0ee5c3 Mon Sep 17 00:00:00 2001
From: Yuqi Li
Date: Wed, 3 May 2023 13:32:20 -0700
Subject: [PATCH 24/28] internal change.

PiperOrigin-RevId: 529181374
---
 .../tasks/vision/core/BaseVisionTaskApi.java | 36 +++++++++++--------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
index 070806522..5964cef2c 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
@@ -77,11 +77,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    inputPackets.put(
-        normRectStreamName,
-        runner
-            .getPacketCreator()
-            .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    if (!normRectStreamName.isEmpty()) {
+      inputPackets.put(
+          normRectStreamName,
+          runner
+              .getPacketCreator()
+              .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets);
   }
 
@@ -105,11 +107,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    inputPackets.put(
-        normRectStreamName,
-        runner
-            .getPacketCreator()
-            .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    if (!normRectStreamName.isEmpty()) {
+      inputPackets.put(
+          normRectStreamName,
+          runner
+              .getPacketCreator()
+              .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }
 
@@ -133,11 +137,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    inputPackets.put(
-        normRectStreamName,
-        runner
-            .getPacketCreator()
-            .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    if (!normRectStreamName.isEmpty()) {
+      inputPackets.put(
+          normRectStreamName,
+          runner
+              .getPacketCreator()
+              .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

From a09e39d431ca216f35bc45c52985449c4b23a189 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 15:44:42 -0700
Subject: [PATCH 25/28] Add TransformerParameters proto

PiperOrigin-RevId: 529213840
---
 .../cc/components/processors/proto/BUILD      |  5 ++
 .../processors/proto/transformer_params.proto | 46 +++++++++++++++++++
 2 files changed, 51 insertions(+)
 create mode 100644 mediapipe/tasks/cc/components/processors/proto/transformer_params.proto

diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD
index 82d4ea21b..a45c91633 100644
--- a/mediapipe/tasks/cc/components/processors/proto/BUILD
+++ b/mediapipe/tasks/cc/components/processors/proto/BUILD
@@ -93,3 +93,8 @@ mediapipe_proto_library(
         "//mediapipe/framework:calculator_proto",
     ],
 )
+
+mediapipe_proto_library(
+    name = "transformer_params_proto",
+    srcs = ["transformer_params.proto"],
+)
diff --git a/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto
new file mode 100644
index 000000000..8c1daf277
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto
@@ -0,0 +1,46 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package mediapipe.tasks.components.processors.proto;
+
+option java_package = "com.google.mediapipe.tasks.components.processors.proto";
+option java_outer_classname = "TransformerParametersProto";
+
+// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf)
+message TransformerParameters {
+  // Batch size of tensors.
+  int32 batch_size = 1;
+
+  // Maximum sequence length of the input/output tensor.
+  int32 max_seq_length = 2;
+
+  // Embedding dimension (or model dimension), `d_model` in the paper.
+  // `d_k` == `d_v` == `d_model`/`h`.
+  int32 embedding_dim = 3;
+
+  // Hidden dimension used in the feedforward layer, `d_ff` in the paper.
+  int32 hidden_dimension = 4;
+
+  // Head dimension, `d_k` or `d_v` in the paper.
+  int32 head_dimension = 5;
+
+  // Number of heads, `h` in the paper.
+  int32 num_heads = 6;
+
+  // Number of stacked transformers, `N` in the paper.
+  int32 num_stacks = 7;
+}

From b350f7239482b542db325f5fa52d5a3e9193778a Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 3 May 2023 16:10:51 -0700
Subject: [PATCH 26/28] Support MultiHW AVG Architecture for object detector

PiperOrigin-RevId: 529221127
---
 .../python/vision/object_detector/model.py   |  9 +++++++--
 .../vision/object_detector/model_spec.py     | 19 +++++++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py
index eac669786..e3eb3a651 100644
--- a/mediapipe/model_maker/python/vision/object_detector/model.py
+++ b/mediapipe/model_maker/python/vision/object_detector/model.py
@@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
     self._num_classes = num_classes
     self._model = self._build_model()
     checkpoint_folder = self._model_spec.downloaded_files.get_path()
-    checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
+    checkpoint_file = os.path.join(
+        checkpoint_folder, self._model_spec.checkpoint_name
+    )
     self.load_checkpoint(checkpoint_file)
     self._model.summary()
     self.loss_trackers = [
@@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
             num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
         ),
         backbone=configs.backbones.Backbone(
-            type='mobilenet', mobilenet=configs.backbones.MobileNet()
+            type='mobilenet',
+            mobilenet=configs.backbones.MobileNet(
+                model_id=self._model_spec.model_id
+            ),
         ),
         decoder=configs.decoders.Decoder(
             type='fpn',
diff --git a/mediapipe/model_maker/python/vision/object_detector/model_spec.py b/mediapipe/model_maker/python/vision/object_detector/model_spec.py
index 2ce838c71..9c89c4ed0 100644
--- a/mediapipe/model_maker/python/vision/object_detector/model_spec.py
+++ b/mediapipe/model_maker/python/vision/object_detector/model_spec.py
@@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
     is_folder=True,
 )
 
+MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetmultiavg',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
+    is_folder=True,
+)
+
 
 @dataclasses.dataclass
 class ModelSpec(object):
@@ -38,13 +44,25 @@ class ModelSpec(object):
   stddev_rgb = (127.5,)
 
   downloaded_files: file_util.DownloadedFiles
+  checkpoint_name: str
   input_image_shape: List[int]
+  model_id: str
 
 
 mobilenet_v2_spec = functools.partial(
     ModelSpec,
     downloaded_files=MOBILENET_V2_FILES,
+    checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
+    model_id='MobileNetV2',
+)
+
+mobilenet_multi_avg_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_MULTI_AVG_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[256, 256, 3],
+    model_id='MobileNetMultiAVG',
 )
 
 
@@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
   """Predefined object detector model specs supported by Model Maker."""
 
   MOBILENET_V2 = mobilenet_v2_spec
+  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
 
   @classmethod
   def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
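On the caller side, picking the new backbone is a one-line change. A sketch under the assumption that Model Maker's public `object_detector` module accepts the spec through `ObjectDetectorOptions` (dataset construction is elided; names are illustrative):

from mediapipe_model_maker import object_detector

options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_MULTI_AVG,
)
# Training would then proceed as usual, e.g.:
# model = object_detector.ObjectDetector.create(
#     train_data=train_data, validation_data=validation_data, options=options)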
From c6e3f0828248c50ff62fe51113b18bb1b7850698 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Wed, 3 May 2023 16:36:45 -0700
Subject: [PATCH 27/28] Expose FaceAligner and LanguageDetector to be public
 MediaPipe Tasks Python API.

PiperOrigin-RevId: 529227382
---
 mediapipe/tasks/python/text/__init__.py   | 5 +++++
 mediapipe/tasks/python/vision/__init__.py | 4 ++++
 2 files changed, 9 insertions(+)

diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py
index 5aa221c33..66c62cafc 100644
--- a/mediapipe/tasks/python/text/__init__.py
+++ b/mediapipe/tasks/python/text/__init__.py
@@ -14,9 +14,13 @@
 """MediaPipe Tasks Text API."""
 
+import mediapipe.tasks.python.text.language_detector
 import mediapipe.tasks.python.text.text_classifier
 import mediapipe.tasks.python.text.text_embedder
 
+LanguageDetector = language_detector.LanguageDetector
+LanguageDetectorOptions = language_detector.LanguageDetectorOptions
+LanguageDetectorResult = language_detector.LanguageDetectorResult
 TextClassifier = text_classifier.TextClassifier
 TextClassifierOptions = text_classifier.TextClassifierOptions
 TextClassifierResult = text_classifier.TextClassifierResult
@@ -26,5 +30,6 @@ TextEmbedderResult = text_embedder.TextEmbedderResult
 
 # Remove unnecessary modules to avoid duplication in API docs.
 del mediapipe
+del language_detector
 del text_classifier
 del text_embedder
diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py
index 3c3f34db0..c88dbb9ad 100644
--- a/mediapipe/tasks/python/vision/__init__.py
+++ b/mediapipe/tasks/python/vision/__init__.py
@@ -15,6 +15,7 @@
 """MediaPipe Tasks Vision API."""
 
 import mediapipe.tasks.python.vision.core
+import mediapipe.tasks.python.vision.face_aligner
 import mediapipe.tasks.python.vision.face_detector
 import mediapipe.tasks.python.vision.face_landmarker
 import mediapipe.tasks.python.vision.face_stylizer
@@ -27,6 +28,8 @@ import mediapipe.tasks.python.vision.interactive_segmenter
 import mediapipe.tasks.python.vision.object_detector
 import mediapipe.tasks.python.vision.pose_landmarker
 
+FaceAligner = face_aligner.FaceAligner
+FaceAlignerOptions = face_aligner.FaceAlignerOptions
 FaceDetector = face_detector.FaceDetector
 FaceDetectorOptions = face_detector.FaceDetectorOptions
 FaceDetectorResult = face_detector.FaceDetectorResult
@@ -65,6 +68,7 @@ RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
 
 # Remove unnecessary modules to avoid duplication in API docs.
 del core
+del face_aligner
 del face_detector
 del face_landmarker
 del face_stylizer

From 381ffcb474c6df8cb7f17f0beced787727a7cd1c Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 4 May 2023 17:10:07 +0530
Subject: [PATCH 28/28] Added hash implementation for iOS normalized keypoint

---
 .../tasks/ios/components/containers/sources/MPPDetection.m | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m
index c245478db..c61cf0b39 100644
--- a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m
+++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m
@@ -28,7 +28,12 @@
   return self;
 }
 
-// TODO: Implement hash
+- (NSUInteger)hash {
+  NSUInteger nonNullPropertiesHash =
+      @(self.location.x).hash ^ @(self.location.y).hash ^ @(self.score).hash;
+
+  return self.label ? nonNullPropertiesHash ^ self.label.hash : nonNullPropertiesHash;
+}
 
 - (BOOL)isEqual:(nullable id)object {
   if (!object) {
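With patch 27 applied, the language detector introduced at the start of this series is importable straight from the public `text` package. A minimal usage sketch (the model path is a placeholder for any compatible language detector asset):

from mediapipe.tasks import python
from mediapipe.tasks.python import text

options = text.LanguageDetectorOptions(
    base_options=python.BaseOptions(
        model_asset_path='language_detector.tflite'))
with text.LanguageDetector.create_from_options(options) as detector:
  result = detector.detect('The quick brown fox jumps over the lazy dog.')
  for prediction in result.languages_and_scores:
    print(prediction.language_code, prediction.probability)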