diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md new file mode 100644 index 000000000..4e9ae721d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -0,0 +1,25 @@ +--- +name: "Tasks Issue" +about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. +labels: type:support + +--- +Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- MediaPipe Tasks SDK version: +- Task name (e.g. Object detection, Gesture recognition etc.): +- Programming Language and version ( e.g. C++, Python, Java): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need :** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem: + +**Other info / Complete Logs :** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md new file mode 100644 index 000000000..31e8d7f1b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -0,0 +1,25 @@ +--- +name: "Model Maker Issue" +about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions. +labels: type:support + +--- +Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Python version (e.g. 3.8): +- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/): +- Task name (e.g. Image classification, Gesture recognition etc.): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need :** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem: + +**Other info / Complete Logs :** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. 
Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md similarity index 82% rename from .github/ISSUE_TEMPLATE/10-solution-issue.md rename to .github/ISSUE_TEMPLATE/13-solution-issue.md index a5332cb36..bf0d613c9 100644 --- a/.github/ISSUE_TEMPLATE/10-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- -name: "Solution Issue" -about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +name: "Solution (legacy) Issue" +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/14-studio-issue.md b/.github/ISSUE_TEMPLATE/14-studio-issue.md new file mode 100644 index 000000000..5942b1eb1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/14-studio-issue.md @@ -0,0 +1,19 @@ +--- +name: "Studio Issue" +about: Use this template for assistance with the MediaPipe Studio application. +labels: type:support + +--- +Please make sure that this is a MediaPipe Studio issue. + +**System information** (Please provide as much relevant information as possible) +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Browser and Version +- Any microphone or camera hardware +- URL that shows the problem + +**Describe the expected behavior:** + +**Other info / Complete Logs :** +Include any js console logs that would be helpful to diagnose the problem. +Large logs and files should be attached: diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev diff --git a/WORKSPACE b/WORKSPACE index 702d1899e..e14473e50 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,21 +22,20 @@ bazel_skylib_workspace() load("@bazel_skylib//lib:versions.bzl", "versions") versions.check(minimum_bazel_version = "3.7.2") -# ABSL cpp library lts_2021_03_24, patch 2. +# ABSL cpp library lts_2023_01_25. http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz", ], - # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ - "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff" + "@//third_party:com_google_absl_windows_patch.diff" ], patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20210324.2", - sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" + strip_prefix = "abseil-cpp-20230125.0", + sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21" ) http_archive( @@ -212,14 +211,14 @@ http_archive( sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91", ) -# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix. 
+# 2022-10-20 http_archive( name = "pybind11", urls = [ - "https://github.com/pybind/pybind11/archive/70a58c577eaf067748c2ec31bfd0b0a614cffba6.zip", + "https://github.com/pybind/pybind11/archive/v2.10.1.zip", ], - sha256 = "b971842fab1b5b8f3815a2302331782b7d137fef0e06502422bc4bc360f4956c", - strip_prefix = "pybind11-70a58c577eaf067748c2ec31bfd0b0a614cffba6", + sha256 = "fcf94065efcfd0a7a828bacf118fa11c43f6390d0c805e3e6342ac119f2e9976", + strip_prefix = "pybind11-2.10.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) @@ -320,12 +319,30 @@ http_archive( ], ) -# iOS basic build deps. +# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee +# that the target @zlib//:mini_zlib is available +http_archive( + name = "zlib", + build_file = "//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], + patches = [ + "@//third_party:zlib.diff", + ], + patch_args = [ + "-p1", + ], +) +# iOS basic build deps. http_archive( name = "build_bazel_rules_apple", - sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2", - url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz", + sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", + url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", patches = [ # Bypass checking ios unit test runner when building MP ios applications. "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" @@ -339,29 +356,24 @@ load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() load( "@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies", ) - swift_rules_dependencies() -http_archive( - name = "build_bazel_apple_support", - sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35", - urls = [ - "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz" - ], +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", ) +swift_rules_extra_dependencies() load( "@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies", ) - apple_support_dependencies() # More iOS deps. @@ -442,25 +454,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow to guarantee that the target -# @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - # TensorFlow repo should always go after the other external dependencies. # TF on 2022-08-10. 
_TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b" diff --git a/docs/BUILD b/docs/BUILD index ad08df66a..80d3ab550 100644 --- a/docs/BUILD +++ b/docs/BUILD @@ -4,12 +4,10 @@ py_binary( name = "build_py_api_docs", srcs = ["build_py_api_docs.py"], deps = [ - "//mediapipe", "//third_party/py/absl:app", "//third_party/py/absl/flags", - "//third_party/py/tensorflow_docs", + "//third_party/py/mediapipe", "//third_party/py/tensorflow_docs/api_generator:generate_lib", - "//third_party/py/tensorflow_docs/api_generator:public_api", ], ) @@ -17,6 +15,7 @@ py_binary( name = "build_java_api_docs", srcs = ["build_java_api_docs.py"], data = [ + "//third_party/android/sdk:api/26.txt", "//third_party/java/doclava/current:doclava.jar", "//third_party/java/jsilver:jsilver_jar", ], diff --git a/docs/build_java_api_docs.py b/docs/build_java_api_docs.py index e96e1fd83..b13e8d1df 100644 --- a/docs/build_java_api_docs.py +++ b/docs/build_java_api_docs.py @@ -20,10 +20,6 @@ from absl import flags from tensorflow_docs.api_generator import gen_java -_JAVA_ROOT = flags.DEFINE_string('java_src', None, - 'Override the Java source path.', - required=False) - _OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/', 'Write docs here.') @@ -37,27 +33,30 @@ _ = flags.DEFINE_bool( 'search_hints', True, '[UNUSED] Include metadata search hints in the generated files') +_ANDROID_SDK = pathlib.Path('android/sdk/api/26.txt') + def main(_) -> None: - if not (java_root := _JAVA_ROOT.value): - # Default to using a relative path to find the Java source. - mp_root = pathlib.Path(__file__) - while (mp_root := mp_root.parent).name != 'mediapipe': - # Find the nearest `mediapipe` dir. - pass + # Default to using a relative path to find the Java source. + mp_root = pathlib.Path(__file__) + while (mp_root := mp_root.parent).name != 'mediapipe': + # Find the nearest `mediapipe` dir. + pass - # Externally, parts of the repo are nested inside a mediapipe/ directory - # that does not exist internally. Support both. - if (mp_root / 'mediapipe').exists(): - mp_root = mp_root / 'mediapipe' + # Find the root from which all packages are relative. + root = mp_root.parent - java_root = mp_root / 'tasks/java' + # Externally, parts of the repo are nested inside a mediapipe/ directory + # that does not exist internally. Support both. + if (mp_root / 'mediapipe').exists(): + mp_root = mp_root / 'mediapipe' gen_java.gen_java_docs( package='com.google.mediapipe', - source_path=pathlib.Path(java_root), + source_path=mp_root / 'tasks/java', output_dir=pathlib.Path(_OUT_DIR.value), - site_path=pathlib.Path(_SITE_PATH.value)) + site_path=pathlib.Path(_SITE_PATH.value), + federated_docs={'https://developer.android.com': root / _ANDROID_SDK}) if __name__ == '__main__': diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py new file mode 100644 index 000000000..7732b7d56 --- /dev/null +++ b/docs/build_model_maker_api_docs.py @@ -0,0 +1,81 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe Model Maker reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker +$> python build_model_maker_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. + import mediapipe_model_maker # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe-model-maker`.') from e + + +PROJECT_SHORT_NAME = 'mediapipe_model_maker' +PROJECT_FULL_NAME = 'MediaPipe Model Maker' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe-model-maker package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)], + base_dir=os.path.dirname(mediapipe_model_maker.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + callbacks=[], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fa1e4314f..02eb04074 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -26,11 +26,10 @@ from absl import app from absl import flags from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. 
- import mediapipe # pytype: disable=import-error + import mediapipe as mp # pytype: disable=import-error except ImportError as e: raise ImportError('Please `pip install mediapipe`.') from e @@ -45,31 +44,30 @@ _OUTPUT_DIR = flags.DEFINE_string( _URL_PREFIX = flags.DEFINE_string( 'code_url_prefix', - 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'https://github.com/google/mediapipe/blob/master/mediapipe', 'The url prefix for links to code.') _SEARCH_HINTS = flags.DEFINE_bool( 'search_hints', True, 'Include metadata search hints in the generated files') -_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python', 'Path prefix in the _toc.yaml') def gen_api_docs(): """Generates API docs for the mediapipe package.""" + if hasattr(mp, 'solutions'): + del mp.solutions doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[(PROJECT_SHORT_NAME, mediapipe)], - base_dir=os.path.dirname(mediapipe.__file__), + py_modules=[(PROJECT_SHORT_NAME, mp)], + base_dir=os.path.dirname(mp.__file__), code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - # This callback ensures that docs are only generated for objects that - # are explicitly imported in your __init__.py files. There are other - # options but this is a good starting point. - callbacks=[public_api.explicit_package_contents_filter], + callbacks=[], ) doc_generator.build(_OUTPUT_DIR.value) diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index e630b073a..d7a028ec3 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -35,7 +35,7 @@ install --user six`. ```bash $ cd $HOME - $ git clone https://github.com/google/mediapipe.git + $ git clone --depth 1 https://github.com/google/mediapipe.git # Change directory into MediaPipe root directory $ cd mediapipe @@ -287,7 +287,7 @@ build issues. 2. Checkout MediaPipe repository. ```bash - $ git clone https://github.com/google/mediapipe.git + $ git clone --depth 1 https://github.com/google/mediapipe.git # Change directory into MediaPipe root directory $ cd mediapipe @@ -416,7 +416,7 @@ build issues. 3. Checkout MediaPipe repository. ```bash - $ git clone https://github.com/google/mediapipe.git + $ git clone --depth 1 https://github.com/google/mediapipe.git $ cd mediapipe ``` @@ -590,7 +590,7 @@ next section. 7. Checkout MediaPipe repository. ``` - C:\Users\Username\mediapipe_repo> git clone https://github.com/google/mediapipe.git + C:\Users\Username\mediapipe_repo> git clone --depth 1 https://github.com/google/mediapipe.git # Change directory into MediaPipe root directory C:\Users\Username\mediapipe_repo> cd mediapipe @@ -680,7 +680,7 @@ cameras. Alternatively, you use a video file as input. 6. Checkout MediaPipe repository. ```bash - username@DESKTOP-TMVLBJ1:~$ git clone https://github.com/google/mediapipe.git + username@DESKTOP-TMVLBJ1:~$ git clone --depth 1 https://github.com/google/mediapipe.git username@DESKTOP-TMVLBJ1:~$ cd mediapipe ``` @@ -771,7 +771,7 @@ This will use a Docker image that will isolate mediapipe's installation from the 2. Build a docker image with tag "mediapipe". ```bash - $ git clone https://github.com/google/mediapipe.git + $ git clone --depth 1 https://github.com/google/mediapipe.git $ cd mediapipe $ docker build --tag=mediapipe .
diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 8c552834e..11589425d 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic # For static images: IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_holistic.Holistic( static_image_mode=True, model_complexity=2, diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 18bcf0c8b..325c41f1b 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -94,8 +94,6 @@ one over the other. * [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite) * [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) -* [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) -* [Model information](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md) ### [Objectron](https://google.github.io/mediapipe/solutions/objectron) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ba461e4a7..4a8f0f598 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "mfcc_mel_calculators_proto", srcs = ["mfcc_mel_calculators.proto"], @@ -197,7 +197,6 @@ cc_library( ":spectrogram_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index c038c0cd7..bd4d8f3bf 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -280,6 +280,13 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, &window); break; + case SpectrogramCalculatorOptions::SQRT_HANN: { + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, + &window); + absl::c_transform(window, window.begin(), + [](double x) { return std::sqrt(x); }); + break; + } } // Propagate settings down to the actual Spectrogram object. 
diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto index 8e1e18051..ddfca1d1c 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.proto +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -68,6 +68,7 @@ message SpectrogramCalculatorOptions { HANN = 0; HAMMING = 1; COSINE = 2; + SQRT_HANN = 4; } optional WindowType window_type = 6 [default = HANN]; diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ecd878115..ecfdd5d0b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -13,16 +13,24 @@ # limitations under the License. # +load("@bazel_skylib//lib:selects.bzl", "selects") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) + +selects.config_setting_group( + name = "ios_or_disable_gpu", + match_any = [ + "//mediapipe/gpu:disable_gpu", + "//mediapipe:ios", + ], +) mediapipe_proto_library( name = "concatenate_vector_calculator_proto", srcs = ["concatenate_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -32,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "dequantize_byte_array_calculator_proto", srcs = ["dequantize_byte_array_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_cloner_calculator_proto", srcs = ["packet_cloner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_thinner_calculator_proto", srcs = ["packet_thinner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +76,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "split_vector_calculator_proto", srcs = ["split_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -82,7 +85,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -92,7 +94,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -102,7 +103,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = 
"gate_calculator_proto", srcs = ["gate_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -112,7 +112,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,7 +123,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "clip_vector_size_calculator_proto", srcs = ["clip_vector_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -134,7 +132,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "flow_limiter_calculator_proto", srcs = ["flow_limiter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -144,7 +141,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "graph_profile_calculator_proto", srcs = ["graph_profile_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -154,7 +150,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "get_vector_item_calculator_proto", srcs = ["get_vector_item_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -164,7 +159,6 @@ mediapipe_proto_library( cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -193,7 +187,6 @@ cc_library( name = "begin_loop_calculator", srcs = ["begin_loop_calculator.cc"], hdrs = ["begin_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -216,7 +209,6 @@ cc_library( name = "end_loop_calculator", srcs = ["end_loop_calculator.cc"], hdrs = ["end_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -258,7 +250,6 @@ cc_test( cc_library( name = "concatenate_vector_calculator_hdr", hdrs = ["concatenate_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -284,7 +275,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -311,7 +301,6 @@ cc_library( cc_library( name = "concatenate_detection_vector_calculator", srcs = ["concatenate_detection_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator", "//mediapipe/framework:calculator_framework", @@ -323,7 +312,6 @@ cc_library( cc_library( name = "concatenate_proto_list_calculator", srcs = ["concatenate_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -341,7 +329,6 @@ cc_test( srcs = 
["concatenate_proto_list_calculator_test.cc"], deps = [ ":concatenate_proto_list_calculator", - ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -373,7 +360,6 @@ cc_library( name = "clip_vector_size_calculator", srcs = ["clip_vector_size_calculator.cc"], hdrs = ["clip_vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -389,7 +375,6 @@ cc_library( cc_library( name = "clip_detection_vector_size_calculator", srcs = ["clip_detection_vector_size_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator", "//mediapipe/framework:calculator_framework", @@ -403,7 +388,6 @@ cc_test( srcs = ["clip_vector_size_calculator_test.cc"], deps = [ ":clip_vector_size_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -417,9 +401,6 @@ cc_test( cc_library( name = "counting_source_calculator", srcs = ["counting_source_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -432,9 +413,6 @@ cc_library( cc_library( name = "make_pair_calculator", srcs = ["make_pair_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -463,9 +441,6 @@ cc_test( cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -479,9 +454,6 @@ cc_library( cc_library( name = "matrix_subtract_calculator", srcs = ["matrix_subtract_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -495,9 +467,6 @@ cc_library( cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -510,9 +479,6 @@ cc_library( cc_library( name = "non_zero_calculator", srcs = ["non_zero_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -558,9 +524,6 @@ cc_test( cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -589,7 +552,6 @@ cc_test( cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -613,9 +575,8 @@ cc_test( cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", 
"//mediapipe/framework/formats:video_stream_header", @@ -632,7 +593,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -645,9 +606,6 @@ cc_test( cc_library( name = "pass_through_calculator", srcs = ["pass_through_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -658,9 +616,6 @@ cc_library( cc_library( name = "round_robin_demux_calculator", srcs = ["round_robin_demux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -672,9 +627,6 @@ cc_library( cc_library( name = "immediate_mux_calculator", srcs = ["immediate_mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -686,7 +638,6 @@ cc_library( cc_library( name = "packet_presence_calculator", srcs = ["packet_presence_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -715,7 +666,6 @@ cc_test( cc_library( name = "previous_loopback_calculator", srcs = ["previous_loopback_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -731,7 +681,6 @@ cc_library( cc_library( name = "flow_limiter_calculator", srcs = ["flow_limiter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -748,7 +697,6 @@ cc_library( cc_library( name = "string_to_int_calculator", srcs = ["string_to_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -761,7 +709,6 @@ cc_library( cc_library( name = "default_side_packet_calculator", srcs = ["default_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -773,7 +720,6 @@ cc_library( cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -824,11 +770,8 @@ cc_library( name = "packet_resampler_calculator", srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], - visibility = [ - "//visibility:public", - ], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -852,7 +795,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -886,7 +829,6 @@ cc_test( 
cc_test( name = "matrix_multiply_calculator_test", srcs = ["matrix_multiply_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_multiply_calculator", "//mediapipe/framework:calculator_framework", @@ -902,7 +844,6 @@ cc_test( cc_test( name = "matrix_subtract_calculator_test", srcs = ["matrix_subtract_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_subtract_calculator", "//mediapipe/framework:calculator_framework", @@ -920,10 +861,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", ":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -952,14 +893,13 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", @@ -968,8 +908,7 @@ cc_library( "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ] + select({ - "//mediapipe/gpu:disable_gpu": [], - "//mediapipe:ios": [], + ":ios_or_disable_gpu": [], "//conditions:default": [ "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", ], @@ -998,7 +937,6 @@ cc_test( cc_library( name = "split_proto_list_calculator", srcs = ["split_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1030,7 +968,6 @@ cc_test( cc_library( name = "dequantize_byte_array_calculator", srcs = ["dequantize_byte_array_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":dequantize_byte_array_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1056,7 +993,6 @@ cc_test( cc_library( name = "quantize_float_vector_calculator", srcs = ["quantize_float_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":quantize_float_vector_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1082,7 +1018,6 @@ cc_test( cc_library( name = "sequence_shift_calculator", srcs = ["sequence_shift_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1107,7 +1042,6 @@ cc_test( cc_library( name = "gate_calculator", srcs = ["gate_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":gate_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1133,7 +1067,6 @@ cc_test( cc_library( name = "matrix_to_vector_calculator", srcs = ["matrix_to_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1169,7 +1102,6 @@ cc_test( cc_library( name = "merge_calculator", srcs = 
["merge_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1195,7 +1127,6 @@ cc_test( cc_library( name = "stream_to_side_packet_calculator", srcs = ["stream_to_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -1221,7 +1152,6 @@ cc_test( cc_library( name = "constant_side_packet_calculator", srcs = ["constant_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1251,7 +1181,6 @@ cc_test( cc_library( name = "graph_profile_calculator", srcs = ["graph_profile_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_profile_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1293,7 +1222,6 @@ cc_library( name = "get_vector_item_calculator", srcs = ["get_vector_item_calculator.cc"], hdrs = ["get_vector_item_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1301,6 +1229,7 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1326,7 +1255,6 @@ cc_library( name = "vector_indices_calculator", srcs = ["vector_indices_calculator.cc"], hdrs = ["vector_indices_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1352,7 +1280,6 @@ cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], hdrs = ["vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1366,9 +1293,6 @@ cc_library( cc_library( name = "packet_sequencer_calculator", srcs = ["packet_sequencer_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:contract", @@ -1386,7 +1310,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", @@ -1403,11 +1327,12 @@ cc_library( name = "merge_to_vector_calculator", srcs = ["merge_to_vector_calculator.cc"], hdrs = ["merge_to_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "@com_google_absl//absl/status", ], @@ -1417,7 +1342,6 @@ cc_library( mediapipe_proto_library( name = "bypass_calculator_proto", srcs = ["bypass_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1427,7 +1351,6 @@ mediapipe_proto_library( cc_library( name = 
"bypass_calculator", srcs = ["bypass_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bypass_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index a9d29e687..6d17f9953 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -49,7 +49,7 @@ namespace mediapipe { // calculator: "EndLoopWithOutputCalculator" // input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts // input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts -// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts +// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts // } // // Input streams tagged with "CLONE" are cloned to the corresponding output diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc index efc0612ec..4e007329b 100644 --- a/mediapipe/calculators/core/bypass_calculator.cc +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -111,6 +111,10 @@ class BypassCalculator : public Node { cc->Outputs().Get(id).SetAny(); } } + for (auto id = cc->InputSidePackets().BeginId(); + id != cc->InputSidePackets().EndId(); ++id) { + cc->InputSidePackets().Get(id).SetAny(); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 45bace271..5d0594de9 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) { : absl::StrCat("Timestamp(", t.DebugString(), ")"); } -template -std::string SourceString(Packet packet) { - std::ostringstream oss; - if (packet.IsEmpty()) { - oss << "Packet()"; - } else { - oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" - << packet.Get() << ")"; - } - oss << ".At(" << SourceString(packet.Timestamp()) << ")"; - return oss.str(); -} - -template -class PacketsEqMatcher - : public ::testing::MatcherInterface { - public: - PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} - void DescribeTo(::std::ostream* os) const override { - *os << "The expected packet contents: \n"; - Print(packets_, os); - } - bool MatchAndExplain( - const PacketContainer& value, - ::testing::MatchResultListener* listener) const override { - if (!Equals(packets_, value)) { - if (listener->IsInterested()) { - *listener << "The actual packet contents: \n"; - Print(value, listener->stream()); - } - return false; - } - return true; - } - - private: - bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { - if (c1.size() != c2.size()) { - return false; - } - for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { - Packet p1 = *i1, p2 = *i2; - if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() || - (!p1.IsEmpty() && - p1.Get() != p2.Get())) { - return false; - } - } - return true; - } - void Print(const PacketContainer& packets, ::std::ostream* os) const { - for (auto it = packets.begin(); it != packets.end(); ++it) { - const Packet& packet = *it; - *os << (it == packets.begin() ? "{" : ""); - *os << SourceString(packet); - *os << (std::next(it) == packets.end() ? 
"}" : ", "); - } - } - - const PacketContainer packets_; -}; - -template -::testing::Matcher PacketsEq( - const PacketContainer& packets) { - return MakeMatcher( - new PacketsEqMatcher(packets)); -} - // A Calculator::Process callback function. typedef std::function @@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2)); + EXPECT_THAT(out_2_packets, + ElementsAreArray(PacketMatchers(expected_output_2))); // Validate the ALLOW stream output. std::vector expected_allow = { @@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { MakePacket(true).At(Timestamp(190000)), MakePacket(false).At(Timestamp(200000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers(expected_allow))); } std::vector StripBoundsUpdates(const std::vector& packets, @@ -891,9 +823,6 @@ std::vector StripBoundsUpdates(const std::vector& packets, // Shows how FlowLimiterCalculator releases auxiliary input packets. // In this test, auxiliary input packets arrive at twice the primary rate. TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket(6).At(Timestamp(60000)), Packet().At(Timestamp(80000)), }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); // Packets following input packets 2 and 6, and not input packets 4 and 8. std::vector expected_auxiliary_output = { @@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { }; std::vector actual_2 = StripBoundsUpdates(out_2_packets, Timestamp(90000)); - EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output)); + EXPECT_THAT(actual_2, + ElementsAreArray(PacketMatchers(expected_auxiliary_output))); std::vector expected_3 = StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999)); std::vector actual_3 = StripBoundsUpdates(out_3_packets, Timestamp(39999)); - EXPECT_THAT(actual_3, IntPacketsEq(expected_3)); + EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers(expected_3))); // Validate the ALLOW stream output. 
std::vector<Packet> expected_allow = { @@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket<bool>(true).At(Timestamp(60000)), MakePacket<bool>(false).At(Timestamp(80000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers<bool>(expected_allow))); } } // anonymous namespace diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc index 51fb46b98..3306e4ff3 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator.cc @@ -15,6 +15,7 @@ #include "mediapipe/calculators/core/get_vector_item_calculator.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" namespace mediapipe { @@ -32,5 +33,9 @@ using GetClassificationListVectorItemCalculator = GetVectorItemCalculator<mediapipe::ClassificationList>; REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); +using GetDetectionVectorItemCalculator = + GetVectorItemCalculator<mediapipe::Detection>; +REGISTER_CALCULATOR(GetDetectionVectorItemCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index dc98ccfe7..ee886b381 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -47,7 +47,7 @@ namespace api2 { // calculator: "Get{SpecificType}VectorItemCalculator" // input_stream: "VECTOR:vector" // input_stream: "INDEX:index" -// input_stream: "ITEM:item" +// output_stream: "ITEM:item" // options { // [mediapipe.GetVectorItemCalculatorOptions.ext] { // item_index: 5 @@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); absl::Status Open(CalculatorContext* cc) final { + cc->SetOffset(mediapipe::TimestampDiff(0)); auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>(); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); return absl::OkStatus(); @@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node { return absl::OkStatus(); } - RET_CHECK(idx >= 0 && idx < items.size()); - kOut(cc).Send(items[idx]); + RET_CHECK(idx >= 0); + RET_CHECK(options.output_empty_on_oob() || idx < items.size()); + + if (idx < items.size()) { + kOut(cc).Send(items[idx]); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/get_vector_item_calculator.proto b/mediapipe/calculators/core/get_vector_item_calculator.proto index c406283e4..9cfb579e4 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.proto +++ b/mediapipe/calculators/core/get_vector_item_calculator.proto @@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions { // Index of vector item to get. INDEX input stream can be used instead, or to // override. optional int32 item_index = 1; + + // Set to true to output an empty packet when the index is out of bounds.
+ optional bool output_empty_on_oob = 2; } diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index c148aa9d1..c2974e20a 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() { )"); } -CalculatorRunner MakeRunnerWithOptions(int set_index) { - return CalculatorRunner(absl::StrFormat(R"( +CalculatorRunner MakeRunnerWithOptions(int set_index, + bool output_empty_on_oob = false) { + return CalculatorRunner( + absl::StrFormat(R"( calculator: "TestGetIntVectorItemCalculator" input_stream: "VECTOR:vector_stream" output_stream: "ITEM:item_stream" options { [mediapipe.GetVectorItemCalculatorOptions.ext] { item_index: %d + output_empty_on_oob: %s } } )", - set_index)); + set_index, output_empty_on_oob ? "true" : "false")); } void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs, @@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { @@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { @@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { @@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); +} + +TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) { + const int try_index = 3; + CalculatorRunner runner = MakeRunnerWithOptions(try_index, true); + const std::vector<int> inputs = {1, 2, 3}; + + AddInputVector(runner, inputs, 1); + + MP_ASSERT_OK(runner.Run()); + + const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre()); } TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index cca64bc9a..fd053ed2b 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/calculators/core/merge_to_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" namespace mediapipe { @@ -23,5 +24,13 @@ namespace api2 { typedef MergeToVectorCalculator MergeImagesToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); +typedef MergeToVectorCalculator + MergeGpuBuffersToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); + +typedef MergeToVectorCalculator + MergeDetectionsToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h index bed616695..f63d86ee4 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.h +++ b/mediapipe/calculators/core/merge_to_vector_calculator.h @@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node { return absl::OkStatus(); } + absl::Status Open(::mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Process(CalculatorContext* cc) { const int input_num = kIn(cc).Count(); - std::vector output_vector(input_num); - std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(), - [](const auto& elem) -> T { return elem.Get(); }); + std::vector output_vector; + for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) { + const auto& elem = *it; + if (!elem.IsEmpty()) { + output_vector.push_back(elem.Get()); + } + } kOut(cc).Send(output_vector); return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index a0ce2ae34..88b04a32b 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -41,6 +41,10 @@ class MuxCalculator : public Node { StreamHandler("MuxInputStreamHandler")); absl::Status Process(CalculatorContext* cc) final { + if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) { + return absl::OkStatus(); + } + int select = *kSelect(cc); RET_CHECK(0 <= select && select < kIn(cc).Count()); if (!kIn(cc)[select].IsEmpty()) { diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index 86d2fab42..6b9434be9 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -398,6 +398,95 @@ TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(MuxCalculatorTest, HandlesCloseGracefully) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + R"pb( + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add single packet wait for completion and close. 
+ MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_TRUE(output_packets.empty()); +} + +TEST(MuxCalculatorTest, HandlesCloseGracefullyWithDeafultInputStreamHandler) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + R"pb( + # This is required in order for EXPECT_DEATH to work everywhere + executor { name: "" type: "ApplicationThreadExecutor" } + + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add single packet wait for completion and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE(output_packets[0].IsEmpty()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index ef3cb9896..e3c92ba52 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; // } // output_stream: "gated_frames" // } -class RealTimeFlowLimiterCalculator : public CalculatorBase { +// +// Please use FlowLimiterCalculator, which replaces this calculator and +// defines a few additional configuration options. +class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.") + RealTimeFlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index 66dbdef2e..026048b79 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node { // The number of packets or timestamps we need to store to output packet[i] at // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; + bool emit_empty_packets_before_first_packet_ = false; }; MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { packet_offset_ = kOffset(cc).GetOr( cc->Options().packet_offset()); + emit_empty_packets_before_first_packet_ = + cc->Options() + .emit_empty_packets_before_first_packet(); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. 
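On the RealTimeFlowLimiterCalculator deprecation above: for the basic single-stream case the replacement is essentially a rename, keeping the FINISHED back-edge wiring already described in the calculator's header comment. A hypothetical before/after node; the stream names are illustrative only.

#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: the same throttling node after swapping in FlowLimiterCalculator.
mediapipe::CalculatorGraphConfig::Node MakeThrottleNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
    # calculator: "RealTimeFlowLimiterCalculator"  (deprecated name)
    calculator: "FlowLimiterCalculator"
    input_stream: "input_video"
    input_stream: "FINISHED:detections"
    input_stream_info { tag_index: "FINISHED" back_edge: true }
    output_stream: "throttled_input_video"
  )pb");
}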
if (packet_offset_ == 0) { @@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { // Ready to output oldest packet with current timestamp. kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); + } else if (emit_empty_packets_before_first_packet_) { + LOG(FATAL) << "Not supported yet"; } // Store current packet for later output. packet_cache_.push_back(kIn(cc).packet()); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto index 15b111d71..36b0bb959 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.proto +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions { optional SequenceShiftCalculatorOptions ext = 107633927; } optional int32 packet_offset = 1 [default = -1]; + + // Emits empty packets before the first delayed packet is emitted. Takes + // effect only when packet offset is set to positive. + optional bool emit_empty_packets_before_first_packet = 2 [default = false]; } diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 89e2d371c..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -16,12 +16,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,7 +30,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "segmentation_smoothing_calculator_proto", srcs = ["segmentation_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
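To make the new SequenceShiftCalculator option concrete: with packet_offset: 2, the packet arriving at timestamp t[i] is re-emitted at t[i+2], so nothing is emitted for the first two input timestamps. emit_empty_packets_before_first_packet is meant to emit empty packets at those leading timestamps instead of staying silent, but as the hunk above shows that branch still hits LOG(FATAL) ("Not supported yet"), so the flag should stay at its default for now. A hypothetical node follows; the untagged stream wiring and stream names are assumptions for illustration.

#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: a two-timestamp delay using the options added in this change.
mediapipe::CalculatorGraphConfig::Node MakeDelayNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
    calculator: "SequenceShiftCalculator"
    input_stream: "frames"
    output_stream: "delayed_frames"
    options {
      [mediapipe.SequenceShiftCalculatorOptions.ext] {
        packet_offset: 2
        # emit_empty_packets_before_first_packet: true  # still LOG(FATAL)s
      }
    }
  )pb");
}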
"//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -112,7 +104,6 @@ cc_library( cc_library( name = "opencv_encoded_image_to_image_frame_calculator", srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_encoded_image_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -127,7 +118,6 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -142,7 +132,6 @@ cc_library( cc_library( name = "opencv_put_text_calculator", srcs = ["opencv_put_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", @@ -156,11 +145,10 @@ cc_library( cc_library( name = "set_alpha_calculator", srcs = ["set_alpha_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -183,11 +171,10 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -212,13 +199,11 @@ cc_library( mediapipe_proto_library( name = "rotation_mode_proto", srcs = ["rotation_mode.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", @@ -243,7 +228,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", @@ -287,13 +271,12 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", @@ -330,7 +313,6 @@ cc_test( cc_library( name = "luminance_calculator", srcs = ["luminance_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,7 +326,6 @@ 
cc_library( cc_library( name = "sobel_edges_calculator", srcs = ["sobel_edges_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -358,15 +339,14 @@ cc_library( cc_library( name = "recolor_calculator", srcs = ["recolor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", + "//mediapipe/util:color_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - "//mediapipe/util:color_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ] + select({ @@ -385,9 +365,6 @@ cc_library( name = "scale_image_utils", srcs = ["scale_image_utils.cc"], hdrs = ["scale_image_utils.h"], - visibility = [ - "//mediapipe:__subpackages__", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -400,12 +377,9 @@ cc_library( cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -429,7 +403,6 @@ cc_library( mediapipe_proto_library( name = "image_clone_calculator_proto", srcs = ["image_clone_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -439,7 +412,6 @@ mediapipe_proto_library( cc_library( name = "image_clone_calculator", srcs = ["image_clone_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_clone_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -459,7 +431,6 @@ cc_library( cc_library( name = "image_properties_calculator", srcs = ["image_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", @@ -524,7 +495,6 @@ cc_test( mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -534,7 +504,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -544,7 +513,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -554,7 +522,6 @@ mediapipe_proto_library( cc_library( name = "mask_overlay_calculator", srcs = ["mask_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":mask_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -570,7 +537,6 @@ cc_library( cc_library( name = 
"feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -597,7 +563,6 @@ cc_library( cc_library( name = "image_file_properties_calculator", srcs = ["image_file_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_file_properties_cc_proto", @@ -627,11 +592,10 @@ cc_test( cc_library( name = "segmentation_smoothing_calculator", srcs = ["segmentation_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", @@ -724,7 +688,6 @@ cc_library( mediapipe_proto_library( name = "warp_affine_calculator_proto", srcs = ["warp_affine_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -736,7 +699,6 @@ cc_library( name = "warp_affine_calculator", srcs = ["warp_affine_calculator.cc"], hdrs = ["warp_affine_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":affine_transformation", ":warp_affine_calculator_cc_proto", @@ -785,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", @@ -817,7 +779,6 @@ cc_test( cc_library( name = "yuv_to_image_calculator", srcs = ["yuv_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index c38fc8e07..361dfc902 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner constexpr GLchar kVertShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner )"; constexpr GLchar kFragShader[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) in vec2 sample_coordinate; uniform sampler2D input_texture; diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index bdac932bb..4781f1ea1 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { constexpr char kRgbaInTag[] = "RGBA_IN"; constexpr char kRgbInTag[] = "RGB_IN"; +constexpr char kBgrInTag[] = "BGR_IN"; constexpr char 
kBgraInTag[] = "BGRA_IN"; constexpr char kGrayInTag[] = "GRAY_IN"; constexpr char kRgbaOutTag[] = "RGBA_OUT"; @@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB -> RGBA // RGBA -> BGRA // BGRA -> RGBA +// BGR -> RGB // // This calculator only supports a single input stream and output stream at a // time. If more than one input stream or output stream is present, the @@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB_IN: The input video stream (ImageFrame, SRGB). // BGRA_IN: The input video stream (ImageFrame, SBGRA). // GRAY_IN: The input video stream (ImageFrame, GRAY8). +// BGR_IN: The input video stream (ImageFrame, SBGR). // // Output streams: // RGBA_OUT: The output video stream (ImageFrame, SRGBA). @@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBgraInTag).Set(); } + if (cc->Inputs().HasTag(kBgrInTag)) { + cc->Inputs().Tag(kBgrInTag).Set(); + } + if (cc->Outputs().HasTag(kRgbOutTag)) { cc->Outputs().Tag(kRgbOutTag).Set(); } @@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA, cv::COLOR_RGBA2BGRA, cc); } + // BGR -> RGB + if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_BGR2RGB, cc); + } return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format conversion."; diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 8c9305ffb..1a2b2e5b0 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; namespace mediapipe { namespace { - +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; #if !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc index b3f692889..3c565282b 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto cc = absl::make_unique( calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); - mediapipe::Rect rect = ParseTextProtoOrDie( + Rect rect = ParseTextProtoOrDie( R"pb( width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5 )pb"); - inputs.Tag(kRectTag).Value() = MakePacket(rect); + inputs.Tag(kRectTag).Value() = MakePacket(rect); RectSpec expectRect = { .width = 1, .height = 1, diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 490d0336a..86a53ffc5 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -142,6 +142,9 @@ absl::Status FindOutputDimensions(int input_width, // static_cast(input_height)); try_width = (try_width / 2) * 2; try_height = (try_height / 2) * 2; + // The output width/height should be greater than 0. 
+ try_width = std::max(try_width, 1); + try_height = std::max(try_height, 1); if (target_height <= 0 || try_height <= target_height) { // The resulting height based on the target width and aspect ratio @@ -160,6 +163,9 @@ absl::Status FindOutputDimensions(int input_width, // static_cast(input_width)); try_width = (try_width / 2) * 2; try_height = (try_height / 2) * 2; + // The output width/height should be greater than 0. + try_width = std::max(try_width, 1); + try_height = std::max(try_height, 1); if (target_width <= 0 || try_width <= target_width) { // The resulting width based on the target width and aspect ratio diff --git a/mediapipe/calculators/image/scale_image_utils_test.cc b/mediapipe/calculators/image/scale_image_utils_test.cc index bda1fa4d6..b4810071c 100644 --- a/mediapipe/calculators/image/scale_image_utils_test.cc +++ b/mediapipe/calculators/image/scale_image_utils_test.cc @@ -124,6 +124,16 @@ TEST(ScaleImageUtilsTest, FindOutputDimensionsPreserveRatio) { &output_width, &output_height)); EXPECT_EQ(151, output_width); EXPECT_EQ(101, output_height); + // Scale to height 1. + MP_ASSERT_OK(FindOutputDimensions(10000, 10, 100, 0, 0, true, 2, + &output_width, &output_height)); + EXPECT_EQ(100, output_width); + EXPECT_EQ(1, output_height); + // Scale to width 1. + MP_ASSERT_OK(FindOutputDimensions(10, 10000, 0, 100, 0, true, 2, + &output_width, &output_height)); + EXPECT_EQ(1, output_width); + EXPECT_EQ(100, output_height); } // Tests scaling without keeping the aspect ratio fixed. diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 54b6c20f1..8647e3f3f 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
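A quick worked example of why the std::max clamp in FindOutputDimensions above matters, using the numbers from the new test: for a 10000x10 input scaled to target width 100 with the aspect ratio preserved, the implied height is 100 * 10 / 10000 = 0.1. The existing rounding to an even value, (try_height / 2) * 2, truncates that to 0; the clamp turns it into 1, so the function returns 100x1 instead of a zero-height output (and, symmetrically, 1x100 for the transposed 10x10000 case).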
-licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -29,14 +29,14 @@ mediapipe_cc_proto_library( name = "callback_packet_calculator_cc_proto", srcs = ["callback_packet_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [":callback_packet_calculator_proto"], ) cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 3f1278397..69d666092 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -24,12 +24,13 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) exports_files( glob(["testdata/image_to_tensor/*"]), visibility = [ "//mediapipe/calculators/image:__subpackages__", + "//mediapipe/util:__subpackages__", ], ) @@ -43,9 +44,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "audio_to_tensor_calculator_proto", srcs = ["audio_to_tensor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -55,17 +53,6 @@ mediapipe_proto_library( cc_library( name = "audio_to_tensor_calculator", srcs = ["audio_to_tensor_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -112,9 +99,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_audio_calculator_proto", srcs = ["tensors_to_audio_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,9 +108,6 @@ mediapipe_proto_library( cc_library( name = "tensors_to_audio_calculator", srcs = ["tensors_to_audio_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":tensors_to_audio_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -163,9 +144,6 @@ cc_test( mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -175,17 +153,6 @@ mediapipe_proto_library( cc_library( name = "feedback_tensors_calculator", srcs = ["feedback_tensors_calculator.cc"], - copts = select({ - # b/215212850 - "//mediapipe:apple": [ - "-x 
objective-c++", - "-fobjc-arc", - ], - "//conditions:default": [], - }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -215,9 +182,6 @@ cc_test( mediapipe_proto_library( name = "bert_preprocessor_calculator_proto", srcs = ["bert_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -227,9 +191,6 @@ mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -273,9 +234,6 @@ cc_test( mediapipe_proto_library( name = "regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -285,9 +243,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -329,9 +284,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -404,7 +356,6 @@ cc_test( mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -423,16 +374,8 @@ cc_library( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], hdrs = ["inference_calculator.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], deps = [ + ":inference_calculator_cc_proto", ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -455,7 +398,6 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_interface", @@ -463,6 +405,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", ], alwayslink = 1, @@ -472,7 +415,6 @@ cc_library( name = "inference_calculator_gl_advanced", srcs = ["inference_calculator_gl_advanced.cc"], tags = ["nomac"], - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -503,15 +445,16 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], - visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", 
+ "//mediapipe/framework/formats:tensor", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:gpu_buffer", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/util/tflite:config", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", @@ -523,15 +466,6 @@ cc_library( cc_library( name = "inference_runner", hdrs = ["inference_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -543,15 +477,6 @@ cc_library( name = "inference_interpreter_delegate_runner", srcs = ["inference_interpreter_delegate_runner.cc"], hdrs = ["inference_interpreter_delegate_runner.h"], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -573,15 +498,6 @@ cc_library( srcs = [ "inference_calculator_cpu.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -620,15 +536,6 @@ cc_library( srcs = [ "inference_calculator_xnnpack.cc", ], - copts = select({ - # TODO: fix tensor.h not to require this, if possible - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -644,7 +551,6 @@ cc_library( cc_library( name = "inference_calculator_gl_if_compute_shader_available", - visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [ @@ -660,7 +566,6 @@ cc_library( # inference_calculator_interface. 
cc_library( name = "inference_calculator", - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_cpu", @@ -674,7 +579,6 @@ cc_library( mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -699,7 +603,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensor_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -718,6 +621,7 @@ cc_library( cc_library( name = "tensor_converter_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -762,7 +666,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_detections_calculator_proto", srcs = ["tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -787,19 +690,18 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", ] + selects.with_or({ ":compute_shader_unavailable": [], @@ -810,6 +712,7 @@ cc_library( cc_library( name = "tensors_to_detections_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", @@ -825,7 +728,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_landmarks_calculator_proto", srcs = ["tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -842,7 +744,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -857,7 +758,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_tensor_calculator_proto", srcs = ["landmarks_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -875,7 +775,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":landmarks_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -908,7 +807,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -925,7 +823,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ 
":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -963,7 +860,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -994,7 +890,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_classification_calculator_proto", srcs = ["tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1032,7 +927,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", @@ -1061,6 +955,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", @@ -1084,7 +979,6 @@ cc_library( mediapipe_proto_library( name = "image_to_tensor_calculator_proto", srcs = ["image_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1130,6 +1024,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:image_test_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1146,7 +1041,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_utils", "//mediapipe/framework/formats:image", @@ -1166,7 +1060,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -1186,6 +1079,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ @@ -1219,6 +1113,7 @@ cc_library( name = "image_to_tensor_converter_gl_texture", srcs = ["image_to_tensor_converter_gl_texture.cc"], hdrs = ["image_to_tensor_converter_gl_texture.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1243,6 +1138,7 @@ cc_library( name = "image_to_tensor_converter_gl_utils", srcs = ["image_to_tensor_converter_gl_utils.cc"], hdrs = ["image_to_tensor_converter_gl_utils.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1272,6 +1168,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe:apple": [ ":image_to_tensor_converter", @@ -1279,7 +1176,6 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", 
"//mediapipe/framework/port:status", @@ -1304,7 +1200,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", "@com_google_absl//absl/status", @@ -1347,7 +1242,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "tensors_to_segmentation_calculator_proto", srcs = ["tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1365,7 +1259,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -1378,9 +1271,9 @@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", - "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/framework/port:statusor", ] + selects.with_or({ "//mediapipe/gpu:disable_gpu": [], @@ -1423,7 +1316,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index d0513518a..9cb23a393 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -43,6 +43,7 @@ namespace api2 { namespace { using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using DftTensorFormat = Options::DftTensorFormat; using FlushMode = Options::FlushMode; std::vector HannWindow(int window_size, bool sqrt_hann) { @@ -188,6 +189,8 @@ class AudioToTensorCalculator : public Node { int padding_samples_before_; int padding_samples_after_; FlushMode flush_mode_; + DftTensorFormat dft_tensor_format_; + Timestamp initial_timestamp_ = Timestamp::Unstarted(); int64 cumulative_input_samples_ = 0; Timestamp next_output_timestamp_ = Timestamp::Unstarted(); @@ -273,6 +276,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { } padding_samples_before_ = options.padding_samples_before(); padding_samples_after_ = options.padding_samples_after(); + dft_tensor_format_ = options.dft_tensor_format(); flush_mode_ = options.flush_mode(); RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ @@ -492,14 +496,43 @@ absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), timestamp); } - Matrix fft_output_matrix = - Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); - fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); - // The last two elements are the DFT Nyquist values. - fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part - fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part - ASSIGN_OR_RETURN(output_tensor, - ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); + switch (dft_tensor_format_) { + case Options::WITH_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); + // The last two elements are Nyquist component. 
+ fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part + ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(fft_output_matrix, + {2, fft_size_ / 2})); + break; + } + case Options::WITH_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data(), 1, fft_size_); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_ + 2); + fft_output_matrix(1) = 0.0f; // DC imagery part. + // The last two elements are Nyquist component. + fft_output_matrix(fft_size_) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ + 1) = 0.0f; // Nyquist imagery part + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ + 2) / 2})); + break; + } + case Options::WITHOUT_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ - 2) / 2})); + break; + } + default: + return absl::InvalidArgumentError("Unsupported dft tensor format."); + } + } else { ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(block, {num_channels_, num_samples_})); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index cff6b2878..aa3c1229c 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -68,4 +68,17 @@ message AudioToTensorCalculatorOptions { } optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; + + enum DftTensorFormat { + DFT_TENSOR_FORMAT_UNKNOWN = 0; + // The output dft tensor without dc and nyquist components. + WITHOUT_DC_AND_NYQUIST = 1; + // The output dft tensor contains the nyquist component as the last + // two values. + WITH_NYQUIST = 2; + // The output dft tensor contains the dc component as the first two values + // and the nyquist component as the last two values. + WITH_DC_AND_NYQUIST = 3; + } + optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST]; } diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 07a5f9fe1..ceb1fc502 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -36,22 +36,17 @@ #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/image_test_utils.h" namespace mediapipe { namespace { -cv::Mat GetRgb(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); - return rgb; -} +constexpr char kTestDataDir[] = + "/mediapipe/calculators/tensor/testdata/" + "image_to_tensor/"; -cv::Mat GetRgba(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); - return rgb; +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDir, filename); } // Image to tensor test template. 
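Circling back to the AudioToTensorCalculator hunks above: the three DftTensorFormat values only change how many complex bins are packed into the output tensor. Reading the ConvertToTensor calls, for a DFT of size fft_size_ = N the output shapes are:

WITHOUT_DC_AND_NYQUIST: {2, (N - 2) / 2}, i.e. bins 1..N/2-1 only.
WITH_NYQUIST (the default and the previous behavior): {2, N / 2}, the same bins plus the Nyquist term appended as (real, 0).
WITH_DC_AND_NYQUIST: {2, (N + 2) / 2}, additionally prepending the DC term as (real, 0).

For example, with N = 256 the shapes are {2, 127}, {2, 128}, and {2, 129} respectively.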
@@ -147,29 +142,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, ASSERT_THAT(tensor_vec, testing::SizeIs(1)); const Tensor& tensor = tensor_vec[0]; + const int channels = tensor.shape().dims[3]; + ASSERT_TRUE(channels == 1 || channels == 3); auto view = tensor.GetCpuReadView(); cv::Mat tensor_mat; if (output_int_tensor) { if (range_min < 0) { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8SC1 : CV_8SC3, const_cast(view.buffer())); } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8UC1 : CV_8UC3, const_cast(view.buffer())); } } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_32FC1 : CV_32FC3, const_cast(view.buffer())); } cv::Mat result_rgb; auto transformation = GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); - tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale, - transformation.offset); + tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3, + transformation.scale, transformation.offset); cv::Mat diff; cv::absdiff(result_rgb, expected_result, diff); @@ -185,17 +185,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, MP_ASSERT_OK(graph.WaitUntilDone()); } +mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { + if (image_channels == 4) { + return ImageFormat::SRGBA; + } else if (image_channels == 3) { + return ImageFormat::SRGB; + } else if (image_channels == 1) { + return ImageFormat::GRAY8; + } + CHECK(false) << "Unsupported input image channles: " << image_channels; +} + Packet MakeImageFramePacket(cv::Mat input) { - ImageFrame input_image( - input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {}); + ImageFrame input_image(GetImageFormat(input.channels()), input.cols, + input.rows, input.step, input.data, [](uint8*) {}); return MakePacket(std::move(input_image)).At(Timestamp(0)); } Packet MakeImagePacket(cv::Mat input) { mediapipe::Image input_image(std::make_shared( - input.channels() == 4 ? 
ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {})); + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } @@ -237,15 +247,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, - /*border mode*/ {}, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -255,11 +262,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -273,11 +277,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -292,11 +293,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -310,16 +309,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb( - "/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_ranges=*/{{-1.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, 
/*keep_aspect=*/false, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation.png")), + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -329,11 +324,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation_border_zero.png")), /*float_ranges=*/{{-1.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, @@ -347,10 +339,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, @@ -364,15 +354,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, - BorderMode::kZero, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { @@ -382,15 +369,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -400,11 +384,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - 
"tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -418,11 +399,23 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -437,11 +430,26 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZeroGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -455,10 +463,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -472,10 +478,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - 
RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index a551e7f8d..eb1726aac 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -285,7 +285,7 @@ class GlProcessor : public ImageToTensorConverter { auto source_texture = gl_helper_.CreateSourceTexture(input); tflite::gpu::gl::GlTexture input_texture( GL_TEXTURE_2D, source_texture.name(), - input_num_channels == 4 ? GL_RGB : GL_RGBA, + input_num_channels == 4 ? GL_RGBA : GL_RGB, source_texture.width() * source_texture.height() * input_num_channels * sizeof(uint8_t), /*layer=*/0, diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 5efd34041..165df8970 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter { constexpr GLchar kExtractSubRectVertexShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter { )"; constexpr GLchar kExtractSubRectFragBody[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) // Provided by kExtractSubRectVertexShader. 
in vec2 sample_coordinate; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index a8211d39b..354547042 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -36,6 +36,10 @@ #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#if MEDIAPIPE_METAL_ENABLED +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" +#endif // MEDIAPIPE_METAL_ENABLED + namespace mediapipe { namespace { @@ -376,7 +380,7 @@ class MetalProcessor : public ImageToTensorConverter { id command_buffer = [metal_helper_ commandBuffer]; const auto& buffer_view = - output_tensor.GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensor, command_buffer); MP_RETURN_IF_ERROR(extractor_->Execute( texture, roi, /*flip_horizontaly=*/false, transform.scale, transform.offset, diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index f910b59f3..95e38f89c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter { switch (tensor_type_) { case Tensor::ElementType::kInt8: mat_type_ = CV_8SC3; + mat_gray_type_ = CV_8SC1; break; case Tensor::ElementType::kFloat32: mat_type_ = CV_32FC3; + mat_gray_type_ = CV_32FC1; break; case Tensor::ElementType::kUInt8: mat_type_ = CV_8UC3; + mat_gray_type_ = CV_8UC1; break; default: mat_type_ = -1; + mat_gray_type_ = -1; } } @@ -64,36 +68,57 @@ class OpenCvProcessor : public ImageToTensorConverter { float range_min, float range_max, int tensor_buffer_offset, Tensor& output_tensor) override { - if (input.image_format() != mediapipe::ImageFormat::SRGB && - input.image_format() != mediapipe::ImageFormat::SRGBA) { - return InvalidArgumentError( - absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast(input.image_format()))); + const bool is_supported_format = + input.image_format() == mediapipe::ImageFormat::SRGB || + input.image_format() == mediapipe::ImageFormat::SRGBA || + input.image_format() == mediapipe::ImageFormat::GRAY8; + if (!is_supported_format) { + return InvalidArgumentError(absl::StrCat( + "Unsupported format: ", static_cast(input.image_format()))); } - // TODO: Remove the check once tensor_buffer_offset > 0 is - // supported. - RET_CHECK_EQ(tensor_buffer_offset, 0) - << "The non-zero tensor_buffer_offset input is not supported yet."; + + RET_CHECK_GE(tensor_buffer_offset, 0) + << "The input tensor_buffer_offset needs to be non-negative."; const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); const int output_height = output_shape.dims[1]; const int output_width = output_shape.dims[2]; const int output_channels = output_shape.dims[3]; + const int num_elements_per_img = + output_height * output_width * output_channels; auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; + const int dst_data_type = output_channels == 1 ? 
mat_gray_type_ : mat_type_; switch (tensor_type_) { case Tensor::ElementType::kInt8: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer()); + RET_CHECK_GE(output_shape.num_elements(), + tensor_buffer_offset / sizeof(int8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer() + tensor_buffer_offset / sizeof(int8)); break; case Tensor::ElementType::kFloat32: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(float) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer() + tensor_buffer_offset / sizeof(float)); break; case Tensor::ElementType::kUInt8: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(uint8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer() + tensor_buffer_offset / sizeof(uint8)); break; default: return InvalidArgumentError( @@ -137,7 +162,8 @@ class OpenCvProcessor : public ImageToTensorConverter { auto transform, GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); - transformed.convertTo(dst, mat_type_, transform.scale, transform.offset); + transformed.convertTo(dst, dst_data_type, transform.scale, + transform.offset); return absl::OkStatus(); } @@ -145,10 +171,9 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { RET_CHECK_EQ(output_shape.dims.size(), 4) << "Wrong output dims size: " << output_shape.dims.size(); - RET_CHECK_EQ(output_shape.dims[0], 1) - << "Handling batch dimension not equal to 1 is not implemented in this " - "converter."; - RET_CHECK_EQ(output_shape.dims[3], 3) + RET_CHECK_GE(output_shape.dims[0], 1) + << "The batch dimension needs to be equal or larger than 1."; + RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1) << "Wrong output channel: " << output_shape.dims[3]; return absl::OkStatus(); } @@ -156,6 +181,7 @@ class OpenCvProcessor : public ImageToTensorConverter { enum cv::BorderTypes border_mode_; Tensor::ElementType tensor_type_; int mat_type_; + int mat_gray_type_; }; } // namespace diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index 3f4c05d4e..3f91f3dc2 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -253,7 +253,14 @@ int GetNumOutputChannels(const mediapipe::Image& image) { } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU - // All of the processors except for Metal expect 3 channels. + // TODO: Add a unittest here to test the behavior on GPU, i.e. + // failure. + // Only output channel == 1 when running on CPU and the input image channel + // is 1. Ideally, we want to also support GPU for output channel == 1. But + // setting this on the safer side to prevent unintentional failure. 
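// A minimal sketch of the bounds check added to OpenCvProcessor above,
// assuming float32 tensor elements: the element offset derived from
// tensor_buffer_offset plus one image worth of elements must still fit in the
// tensor allocated for the whole batch, otherwise writing through the cv::Mat
// would run past the buffer. OffsetFitsInTensor is an illustrative name, not
// part of the calculator.
#include <cstddef>

bool OffsetFitsInTensor(std::size_t num_tensor_elements,
                        std::size_t tensor_buffer_offset_bytes,
                        std::size_t num_elements_per_img) {
  const std::size_t offset_elements =
      tensor_buffer_offset_bytes / sizeof(float);
  return num_tensor_elements >= offset_elements + num_elements_per_img;
}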
+ if (!image.UsesGpu() && image.channels() == 1) { + return 1; + } return 3; } diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 4ccdc07e1..2a6936eba 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -63,6 +63,10 @@ class InferenceCalculatorSelectorImpl for (const auto& suffix : impls) { const auto impl = absl::StrCat("InferenceCalculator", suffix); if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue; + VLOG(1) << "Using " << suffix << " for InferenceCalculator with " + << (options.has_model_path() + ? "model " + options.model_path() + : "output_stream " + subgraph_node.output_stream(0)); CalculatorGraphConfig::Node impl_node = subgraph_node; impl_node.set_calculator(impl); return tool::MakeSingleNodeGraph(std::move(impl_node)); diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 46552803b..78a0039bc 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; option java_package = "com.google.mediapipe.calculator.proto"; option java_outer_classname = "InferenceCalculatorProto"; diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index bd8eb3eed..27b8bc23a 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -20,6 +20,7 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/calculator_context.h" @@ -154,6 +155,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); + gpu_buffers_in_.emplace_back(absl::make_unique( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector{ @@ -171,6 +176,9 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( // Create and bind output buffers. 
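// A minimal sketch of the channel-selection rule introduced in
// GetNumOutputChannels above, ignoring the earlier GPU/Metal-specific branches
// that the hunk does not touch: grayscale input keeps a single output channel
// only on the CPU path, everything else stays at three channels.
// NumOutputChannelsSketch is an illustrative free function, not the real
// helper.
inline int NumOutputChannelsSketch(bool image_uses_gpu, int image_channels) {
  if (!image_uses_gpu && image_channels == 1) {
    return 1;  // CPU + GRAY8 input -> single-channel tensor.
  }
  return 3;  // GPU paths and multi-channel images keep an RGB-shaped tensor.
}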
for (int i = 0; i < output_size_; ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); gpu_buffers_out_.emplace_back(absl::make_unique( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector{ diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index ad5df849f..8fd55efa7 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -236,14 +236,21 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( const mediapipe::InferenceCalculatorOptions& options, const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& gpu_delegate_options) { - use_kernel_caching_ = gpu_delegate_options.has_cached_kernel_path(); + // The kernel cache needs a unique filename based on either model_path or the + // model token, to prevent the cache from being overwritten if the graph has + // more than one model. + use_kernel_caching_ = + gpu_delegate_options.has_cached_kernel_path() && + (options.has_model_path() || gpu_delegate_options.has_model_token()); use_serialized_model_ = gpu_delegate_options.has_serialized_model_dir() && gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { - cached_kernel_filename_ = gpu_delegate_options.cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; + std::string basename = options.has_model_path() + ? mediapipe::File::Basename(options.model_path()) + : gpu_delegate_options.model_token(); + cached_kernel_filename_ = mediapipe::file::JoinPath( + gpu_delegate_options.cached_kernel_path(), basename + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = @@ -258,9 +265,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { if (use_kernel_caching_) { // Save kernel file. - auto kernel_cache = absl::make_unique>( - gpu_runner->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector kernel_cache, + gpu_runner->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index a85071f3e..fba18a81c 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -22,7 +22,10 @@ #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/gpu_buffer.h" @@ -149,11 +152,12 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { command_buffer.label = @"InferenceCalculator"; // Explicit copy input with conversion float 32 bits to 16 bits. 
for (int i = 0; i < input_tensors.size(); ++i) { - auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer); + auto input_view = + MtlBufferView::GetReadView(input_tensors[i], command_buffer); // Reshape tensor. tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape()); auto gpu_buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*gpu_buffers_in_[i], command_buffer); id input_encoder = [command_buffer computeCommandEncoder]; [converter_to_BPHWC4_ convertWithEncoder:input_encoder @@ -173,9 +177,10 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { output_shapes_[i]); // Reshape tensor. tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]); - auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer); + auto read_view = + MtlBufferView::GetReadView(*gpu_buffers_out_[i], command_buffer); auto write_view = - output_tensors->at(i).GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensors->at(i), command_buffer); id output_encoder = [command_buffer computeCommandEncoder]; [converter_from_BPHWC4_ convertWithEncoder:output_encoder @@ -245,6 +250,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); // Create and bind input buffer. std::vector dims{tensor->dims->data, tensor->dims->data + tensor->dims->size}; @@ -254,7 +262,7 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( : Tensor::ElementType::kFloat32, Tensor::Shape{dims})); auto buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); + MtlBufferView::GetWriteView(*gpu_buffers_in_[i], gpu_helper_.mtlDevice); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_.get(), input_indices[i], buffer_view.buffer()), true); @@ -266,6 +274,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( output_shapes_.resize(output_indices.size()); for (int i = 0; i < output_shapes_.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); RET_CHECK(tensor->dims->size <= 4); // Create and bind output buffers. // Channels are always padded to multiple of 4. 
@@ -279,8 +290,8 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( Tensor::Shape{dims})); RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( delegate_.get(), output_indices[i], - gpu_buffers_out_[i] - ->GetMtlBufferWriteView(gpu_helper_.mtlDevice) + MtlBufferView::GetWriteView(*gpu_buffers_out_[i], + gpu_helper_.mtlDevice) .buffer()), true); } diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index 0b750b859..4b05488fd 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -31,6 +31,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_calculator_helper.h" @@ -304,7 +305,7 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { id src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input]; [compute_encoder setTexture:src_texture atIndex:0]; auto output_view = - output_tensors->at(0).GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(output_tensors->at(0), command_buffer); [compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1]; MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); MTLSize threadgroups = diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 97ef01b4c..4bb3f0f57 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -41,6 +41,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -536,10 +537,11 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); auto command_buffer = [gpu_helper_ commandBuffer]; - auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto src_buffer = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.anchors_tensor_index()], + command_buffer); auto dest_buffer = - raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*raw_anchors_buffer_, command_buffer); id blit_command = [command_buffer blitCommandEncoder]; [blit_command copyFromBuffer:src_buffer.buffer() @@ -571,15 +573,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:decode_program_]; { auto scored_boxes_view = - scored_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*scored_boxes_buffer_, command_buffer); auto decoded_boxes_view = - decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); + MtlBufferView::GetWriteView(*decoded_boxes_buffer_, command_buffer); [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; - auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input0_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.detections_tensor_index()], + command_buffer); [command_encoder setBuffer:input0_view.buffer() 
offset:0 atIndex:1]; auto raw_anchors_view = - raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); + MtlBufferView::GetReadView(*raw_anchors_buffer_, command_buffer); [command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2]; MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1); MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1); @@ -588,8 +591,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:score_program_]; [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; - auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()] - .GetMtlBufferReadView(command_buffer); + auto input1_view = MtlBufferView::GetReadView( + input_tensors[tensor_mapping_.scores_tensor_index()], command_buffer); [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 172f70880..839451ab7 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -53,6 +53,7 @@ #import #import +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" #endif // MEDIAPIPE_METAL_ENABLED @@ -485,7 +486,8 @@ absl::Status TensorsToSegmentationCalculator::ProcessGpu( [command_buffer computeCommandEncoder]; [command_encoder setComputePipelineState:mask_program_]; - auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + auto read_view = + MtlBufferView::GetReadView(input_tensors[0], command_buffer); [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0]; mediapipe::GpuBuffer small_mask_buffer = [metal_helper_ diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index d0dfc12ab..4aec15dcb 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "graph_tensors_packet_generator_proto", srcs = ["graph_tensors_packet_generator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework:packet_generator_proto", @@ -32,49 +31,42 @@ proto_library( proto_library( name = "matrix_to_tensor_calculator_options_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "lapped_tensor_buffer_calculator_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "object_detection_tensors_to_detections_calculator_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensorflow_inference_calculator_proto", srcs = ["tensorflow_inference_calculator.proto"], - visibility = 
["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_image_frame_calculator_proto", srcs = ["tensor_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_matrix_calculator_proto", srcs = ["tensor_to_matrix_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:time_series_header_proto", @@ -84,30 +76,24 @@ proto_library( proto_library( name = "tensor_to_vector_float_calculator_options_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_int_calculator_options_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_string_calculator_options_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) mediapipe_proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_proto", "//mediapipe/framework:calculator_proto", @@ -118,14 +104,12 @@ mediapipe_proto_library( proto_library( name = "vector_float_to_tensor_calculator_options_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "vector_string_to_tensor_calculator_options_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -136,7 +120,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":graph_tensors_packet_generator_proto"], ) @@ -147,7 +130,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":image_frame_to_tensor_calculator_proto"], ) @@ -155,7 +137,6 @@ mediapipe_cc_proto_library( name = "matrix_to_tensor_calculator_options_cc_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":matrix_to_tensor_calculator_options_proto"], ) @@ -163,7 +144,6 @@ mediapipe_cc_proto_library( name = "lapped_tensor_buffer_calculator_cc_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":lapped_tensor_buffer_calculator_proto"], ) @@ -171,7 +151,6 @@ mediapipe_cc_proto_library( name = "object_detection_tensors_to_detections_calculator_cc_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], cc_deps = 
["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":object_detection_tensors_to_detections_calculator_proto"], ) @@ -179,7 +158,6 @@ mediapipe_cc_proto_library( name = "tensorflow_inference_calculator_cc_proto", srcs = ["tensorflow_inference_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensorflow_inference_calculator_proto"], ) @@ -190,7 +168,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_generator_proto"], ) @@ -201,7 +178,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_calculator_proto"], ) @@ -212,7 +188,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -223,7 +198,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -231,7 +205,6 @@ mediapipe_cc_proto_library( name = "tensor_squeeze_dimensions_calculator_cc_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_squeeze_dimensions_calculator_proto"], ) @@ -239,7 +212,6 @@ mediapipe_cc_proto_library( name = "tensor_to_image_frame_calculator_cc_proto", srcs = ["tensor_to_image_frame_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_image_frame_calculator_proto"], ) @@ -250,7 +222,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tensor_to_matrix_calculator_proto"], ) @@ -258,7 +229,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_float_calculator_options_cc_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_float_calculator_options_proto"], ) @@ -266,7 +236,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_int_calculator_options_cc_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_int_calculator_options_proto"], ) @@ -274,7 +243,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_string_calculator_options_cc_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_string_calculator_options_proto"], ) @@ -285,7 +253,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = 
[":vector_int_to_tensor_calculator_options_proto"], ) @@ -293,7 +260,6 @@ mediapipe_cc_proto_library( name = "vector_float_to_tensor_calculator_options_cc_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_float_to_tensor_calculator_options_proto"], ) @@ -301,14 +267,12 @@ mediapipe_cc_proto_library( name = "vector_string_to_tensor_calculator_options_cc_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_string_to_tensor_calculator_options_proto"], ) cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_tensors_packet_generator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -323,7 +287,6 @@ cc_library( cc_library( name = "image_frame_to_tensor_calculator", srcs = ["image_frame_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -344,10 +307,9 @@ cc_library( cc_library( name = "matrix_to_tensor_calculator", srcs = ["matrix_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -366,7 +328,6 @@ cc_library( cc_library( name = "lapped_tensor_buffer_calculator", srcs = ["lapped_tensor_buffer_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,9 +349,6 @@ cc_library( # Layering check doesn't play nicely with portable proto wrappers. 
"no_layering_check", ], - visibility = [ - "//visibility:public", - ], deps = [ ":object_detection_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,14 +365,11 @@ cc_library( cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", # build_cleaner: keep + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", @@ -432,9 +387,6 @@ cc_library( cc_library( name = "string_to_sequence_example_calculator", srcs = ["string_to_sequence_example_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -449,10 +401,9 @@ cc_library( cc_library( name = "tensorflow_inference_calculator", srcs = ["tensorflow_inference_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - ":tensorflow_session", ":tensorflow_inference_calculator_cc_proto", + ":tensorflow_session", "@com_google_absl//absl/log:check", "//mediapipe/framework:timestamp", "@com_google_absl//absl/base:core_headers", @@ -487,7 +438,6 @@ cc_library( "tensorflow_session.h", ], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:core", @@ -505,7 +455,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_calculator", srcs = ["tensorflow_session_from_frozen_graph_calculator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", @@ -515,6 +464,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -536,7 +486,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_generator", srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_frozen_graph_generator_cc_proto", @@ -546,6 +495,7 @@ cc_library( "//mediapipe/framework/deps:clock", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -570,7 +520,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_calculator_cc_proto", @@ -609,7 +558,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", @@ -635,7 +583,6 @@ cc_library( cc_library( name = "tensor_squeeze_dimensions_calculator", srcs = 
["tensor_squeeze_dimensions_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_squeeze_dimensions_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -649,7 +596,6 @@ cc_library( cc_library( name = "tensor_to_image_frame_calculator", srcs = ["tensor_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -664,10 +610,9 @@ cc_library( cc_library( name = "tensor_to_matrix_calculator", srcs = ["tensor_to_matrix_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -686,7 +631,6 @@ cc_library( cc_library( name = "tfrecord_reader_calculator", srcs = ["tfrecord_reader_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -702,12 +646,11 @@ cc_library( cc_library( name = "tensor_to_vector_float_calculator", srcs = ["tensor_to_vector_float_calculator.cc"], - visibility = ["//visibility:public"], deps = [ + ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - ":tensor_to_vector_float_calculator_options_cc_proto", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -722,7 +665,6 @@ cc_library( cc_library( name = "tensor_to_vector_int_calculator", srcs = ["tensor_to_vector_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_int_calculator_options_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -744,7 +686,6 @@ cc_library( cc_library( name = "tensor_to_vector_string_calculator", srcs = ["tensor_to_vector_string_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -764,9 +705,6 @@ cc_library( cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", @@ -784,7 +722,6 @@ cc_library( cc_library( name = "vector_int_to_tensor_calculator", srcs = ["vector_int_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_int_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -798,7 +735,6 @@ cc_library( cc_library( name = "vector_float_to_tensor_calculator", srcs = ["vector_float_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_float_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -812,7 +748,6 @@ cc_library( cc_library( name = "vector_string_to_tensor_calculator", srcs = ["vector_string_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_string_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -826,7 +761,6 @@ cc_library( cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = 
["unpack_yt8m_sequence_example_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1083,7 +1017,6 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", - ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:image_frame", @@ -1121,6 +1054,7 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/port:gtest_main", + "//mediapipe/util:packet_test_util", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -1236,6 +1170,7 @@ cc_test( data = [":test_frozen_graph"], linkstatic = 1, deps = [ + ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc index cd807b87b..ec7cd70fa 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc @@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) { } auto output = absl::make_unique>(input_tensor.NumElements()); - const auto& tensor_values = input_tensor.flat(); + const auto& tensor_values = input_tensor.unaligned_flat(); for (int i = 0; i < input_tensor.NumElements(); ++i) { output->at(i) = tensor_values(i); } diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc index 69d3af60a..98ba4f020 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator_test.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/util/packet_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" @@ -129,5 +130,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) { } } +TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) { + SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false); + + const tf::TensorShape tensor_shape(std::vector{2, 5}); + tf::Tensor tensor(tf::DT_FLOAT, tensor_shape); + auto slice = tensor.Slice(1, 1).flat(); + for (int i = 0; i < 5; ++i) { + slice(i) = i; + } + + auto input_tensor = tensor.SubSlice(1); + // Ensure that the input tensor is unaligned. 
+ ASSERT_FALSE(input_tensor.IsAligned()); + runner_->MutableInputs()->Index(0).packets.push_back( + MakePacket(input_tensor).At(Timestamp(5))); + + ASSERT_TRUE(runner_->Run().ok()); + + EXPECT_THAT(runner_->Outputs().Index(0).packets, + ElementsAre(PacketContainsTimestampAndPayload>( + Timestamp(5), std::vector({0, 1, 2, 3, 4})))); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc index 2f4ff28cf..f92ddf08d 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc @@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase { private: void TokenizeVector(std::vector* vector) const; + void RemoveOverlapVector(std::vector* vector); TensorToVectorIntCalculatorOptions options_; + int32_t overlapping_values_; }; REGISTER_CALCULATOR(TensorToVectorIntCalculator); @@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) { absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); + overlapping_values_ = 0; // Inform mediapipe that this calculator produces an output at time t for // each input received at time t (i.e. this calculator does not buffer @@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(&instance_output); + RemoveOverlapVector(&instance_output); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } else { @@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(output.get()); + RemoveOverlapVector(output.get()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } return absl::OkStatus(); } +void TensorToVectorIntCalculator::RemoveOverlapVector( + std::vector* vector) { + if (options_.overlap() <= 0) { + return; + } + if (overlapping_values_ > 0) { + if (vector->size() < overlapping_values_) { + vector->clear(); + } else { + vector->erase(vector->begin(), vector->begin() + overlapping_values_); + } + } + overlapping_values_ = options_.overlap(); +} + void TensorToVectorIntCalculator::TokenizeVector( std::vector* vector) const { if (!options_.tensor_is_token()) { diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto index 9da3298b9..76b9be952 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto @@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions { optional bool tensor_is_token = 3 [default = false]; // Threshold for the token generation. optional float token_threshold = 4 [default = 0.5]; + + // Values which overlap between timely following vectors. They are removed + // from the output to reduce redundancy. 
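// A minimal sketch of the overlap handling added to
// TensorToVectorIntCalculator above, with the calculator state passed in
// explicitly: the first emitted vector is kept in full, and every later
// vector drops the leading `overlap` values because they duplicate the tail
// of the previous output. RemoveOverlap is an illustrative free function, not
// the member RemoveOverlapVector().
#include <cstdint>
#include <vector>

void RemoveOverlap(int32_t overlap, int32_t* overlapping_values,
                   std::vector<int64_t>* vector) {
  if (overlap <= 0) return;
  if (*overlapping_values > 0) {
    if (vector->size() < static_cast<size_t>(*overlapping_values)) {
      vector->clear();
    } else {
      vector->erase(vector->begin(), vector->begin() + *overlapping_values);
    }
  }
  // From now on, each subsequent vector has `overlap` leading values removed.
  *overlapping_values = overlap;
}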
+ optional int32 overlap = 5 [default = 0]; } diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc index 60c0d47ec..406c2c1a7 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc @@ -28,7 +28,8 @@ namespace tf = ::tensorflow; class TensorToVectorIntCalculatorTest : public ::testing::Test { protected: void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd, - const bool tensor_is_token = false) { + const bool tensor_is_token = false, + const int32_t overlap = 0) { CalculatorGraphConfig::Node config; config.set_calculator("TensorToVectorIntCalculator"); config.add_input_stream("input_tensor"); @@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test { options->set_tensor_is_2d(tensor_is_2d); options->set_flatten_nd(flatten_nd); options->set_tensor_is_token(tensor_is_token); + options->set_overlap(overlap); runner_ = absl::make_unique(config); } @@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) { } } +TEST_F(TensorToVectorIntCalculatorTest, Overlap) { + SetUpRunner(false, false, false, 2); + for (int time = 0; time < 3; ++time) { + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_INT64, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i can be represented exactly in floating point numbers if 'i' is + // small. + tensor_vec(i) = static_cast(time + (1 << i)); + } + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + } + + ASSERT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(3, output_packets.size()); + + { + // First vector in full. + int time = 0; + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const int64 expected = static_cast(time + (1 << i)); + EXPECT_EQ(expected, output_vector[i]); + } + } + + // All following vectors the overlap removed + for (int time = 1; time < 3; ++time) { + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(3, output_vector.size()); + for (int i = 0; i < 3; ++i) { + const int64 expected = static_cast(time + (1 << (i + 2))); + EXPECT_EQ(expected, output_vector[i]); + } + } +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 922eb9d50..18bddbbe3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /, -, . and :'s with _'s. 
This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. @@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index 927d3b51f..515b46fa9 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase as well as switch + // /, -, .and :'s to _'s, which enables common signatures to be used as Tags. optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index d5236f1cc..ee69ec56a 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /, -, and .'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. 
@@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index d24a1cd73..d45fcb662 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase, as well as switch /'s + // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags. optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index d8562ffc4..fbf775403 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -647,7 +647,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) { TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { // TODO: Suport proto3 proto.Any in CalculatorOptions. - // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS". + // TODO: Avoid google::protobuf extensions in "RESAMPLER_OPTIONS". 
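// A minimal, self-contained sketch of the tag conversion consolidated in the
// two MaybeConvertSignatureToTag() hunks above: the signature key is
// upper-cased, and '/', '-', '.', and ':' are all replaced with '_' in a
// single absl::StrReplaceAll call, so e.g. "serving_default:0" becomes
// "SERVING_DEFAULT_0". SignatureToTag is an illustrative name for the sketch.
#include <algorithm>
#include <cctype>
#include <string>

#include "absl/strings/str_replace.h"

std::string SignatureToTag(const std::string& name) {
  std::string output(name.length(), '\0');
  std::transform(name.begin(), name.end(), output.begin(),
                 [](unsigned char c) { return std::toupper(c); });
  return absl::StrReplaceAll(
      output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
}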
CalculatorOptions options; options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_padding_before_label(1); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2007a4fe1..db2a27630 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -18,12 +18,11 @@ load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -33,7 +32,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -43,7 +41,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -53,7 +50,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -63,7 +59,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -73,7 +68,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -103,7 +95,6 @@ mediapipe_proto_library( cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -117,7 +108,6 @@ cc_library( cc_library( name = "tflite_custom_op_resolver_calculator", srcs = ["tflite_custom_op_resolver_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -208,7 +198,6 
@@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -287,10 +276,9 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_converter_calculator_cc_proto", + "//mediapipe/util/tflite:config", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -326,7 +314,6 @@ cc_library( cc_library( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -340,7 +327,6 @@ cc_library( cc_library( name = "tflite_tensors_to_segmentation_calculator", srcs = ["tflite_tensors_to_segmentation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -408,17 +394,16 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/util/tflite:config", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/deps:file_path", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + selects.with_or({ @@ -444,7 +429,6 @@ cc_library( cc_library( name = "tflite_tensors_to_classification_calculator", srcs = ["tflite_tensors_to_classification_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -476,7 +460,6 @@ cc_library( cc_library( name = "tflite_tensors_to_landmarks_calculator", srcs = ["tflite_tensors_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -490,7 +473,6 @@ cc_library( cc_library( name = "tflite_tensors_to_floats_calculator", srcs = ["tflite_tensors_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index afdc9ed6f..0f7fa933e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. 
- auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( - tflite_gpu_runner_->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache, + tflite_gpu_runner_->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3a9ddc36f..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -21,10 +21,9 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + ":detections_to_rects_calculator", + ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -50,7 +48,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -61,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "filter_detections_calculator_proto", srcs = ["filter_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -71,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -81,13 +76,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -97,13 +90,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -113,7 +104,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], - visibility = ["//visibility:public"], deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -123,7 +113,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -133,7 +122,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -143,7 +131,6 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -188,7 +175,6 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", @@ -228,9 +214,6 @@ cc_test( cc_library( name = "clock_timestamp_calculator", srcs = ["clock_timestamp_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -246,9 +229,6 @@ cc_library( cc_library( name = "clock_latency_calculator", srcs = ["clock_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -263,11 +243,10 @@ cc_library( cc_library( name = "annotation_overlay_calculator", srcs = ["annotation_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", @@ -296,7 +275,6 @@ cc_library( cc_library( name = "detection_label_id_to_text_calculator", srcs = ["detection_label_id_to_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -328,7 +306,6 @@ cc_library( cc_library( name = "timed_box_list_id_to_label_calculator", srcs = ["timed_box_list_id_to_label_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_id_to_label_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -357,7 +334,6 @@ cc_library( cc_library( name = "detection_transformation_calculator", srcs = ["detection_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -391,7 +367,6 @@ cc_test( cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":non_max_suppression_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -408,7 +383,6 @@ cc_library( cc_library( name = "thresholding_calculator", srcs = 
["thresholding_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":thresholding_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -421,7 +395,6 @@ cc_library( cc_library( name = "detection_to_landmarks_calculator", srcs = ["detection_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -436,7 +409,6 @@ cc_library( cc_library( name = "filter_detections_calculator", srcs = ["filter_detections_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -450,7 +422,6 @@ cc_library( cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_detection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -471,7 +442,6 @@ cc_library( hdrs = [ "detections_to_rects_calculator.h", ], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -486,10 +456,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detections_deduplicate_calculator", + srcs = [ + "detections_deduplicate_calculator.cc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_transformation_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -504,7 +490,6 @@ cc_library( cc_library( name = "rect_projection_calculator", srcs = ["rect_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", @@ -535,7 +520,6 @@ cc_test( mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -547,7 +531,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -557,7 +540,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -569,7 +551,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -581,7 +562,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = 
"timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -593,7 +573,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -605,7 +584,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -617,7 +595,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -627,7 +604,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -637,7 +613,6 @@ mediapipe_proto_library( cc_library( name = "landmark_visibility_calculator", srcs = ["landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -649,7 +624,6 @@ cc_library( cc_library( name = "set_landmark_visibility_calculator", srcs = ["set_landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -661,7 +635,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -671,7 +644,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -681,7 +653,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -693,7 +664,6 @@ mediapipe_proto_library( cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -713,7 +683,6 @@ cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], hdrs = ["landmarks_to_render_data_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_render_data_calculator_cc_proto", 
"//mediapipe/framework:calculator_framework", @@ -732,7 +701,6 @@ cc_library( cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -751,11 +719,9 @@ cc_library( cc_library( name = "labels_to_render_data_calculator", srcs = ["labels_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:ret_check", @@ -771,7 +737,6 @@ cc_library( cc_library( name = "rect_to_render_data_calculator", srcs = ["rect_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -786,7 +751,6 @@ cc_library( cc_library( name = "rect_to_render_scale_calculator", srcs = ["rect_to_render_scale_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_scale_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -821,7 +785,6 @@ cc_test( cc_library( name = "detection_letterbox_removal_calculator", srcs = ["detection_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -835,7 +798,6 @@ cc_library( cc_library( name = "detection_projection_calculator", srcs = ["detection_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -868,7 +830,6 @@ cc_test( cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -882,7 +843,6 @@ cc_library( cc_library( name = "landmark_projection_calculator", srcs = ["landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmark_projection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -915,7 +875,6 @@ cc_test( cc_library( name = "world_landmark_projection_calculator", srcs = ["world_landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -929,7 +888,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -939,7 +897,6 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -957,7 +914,6 @@ cc_library( mediapipe_proto_library( name = "visibility_smoothing_calculator_proto", srcs = ["visibility_smoothing_calculator.proto"], - visibility = 
["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -967,7 +923,6 @@ mediapipe_proto_library( cc_library( name = "visibility_smoothing_calculator", srcs = ["visibility_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -983,7 +938,6 @@ cc_library( mediapipe_proto_library( name = "visibility_copy_calculator_proto", srcs = ["visibility_copy_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -993,7 +947,6 @@ mediapipe_proto_library( cc_library( name = "visibility_copy_calculator", srcs = ["visibility_copy_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_copy_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1008,7 +961,6 @@ cc_library( cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1055,7 +1007,6 @@ cc_test( mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1065,7 +1016,6 @@ mediapipe_proto_library( cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":top_k_scores_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1109,7 +1059,6 @@ cc_test( mediapipe_proto_library( name = "local_file_contents_calculator_proto", srcs = ["local_file_contents_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1119,7 +1068,6 @@ mediapipe_proto_library( cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1133,7 +1081,6 @@ cc_library( cc_library( name = "local_file_pattern_contents_calculator", srcs = ["local_file_pattern_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:file_helpers", @@ -1147,7 +1094,6 @@ cc_library( name = "filter_collection_calculator", srcs = ["filter_collection_calculator.cc"], hdrs = ["filter_collection_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", @@ -1165,7 +1111,6 @@ cc_library( name = "collection_has_min_size_calculator", srcs = ["collection_has_min_size_calculator.cc"], hdrs = ["collection_has_min_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1193,7 +1138,6 @@ cc_test( cc_library( name = "association_calculator", hdrs = ["association_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":association_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1210,7 +1154,6 @@ 
cc_library( cc_library( name = "association_norm_rect_calculator", srcs = ["association_norm_rect_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1225,7 +1168,6 @@ cc_library( cc_library( name = "association_detection_calculator", srcs = ["association_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1260,7 +1202,6 @@ cc_test( cc_library( name = "detections_to_timed_box_list_calculator", srcs = ["detections_to_timed_box_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1275,7 +1216,6 @@ cc_library( cc_library( name = "detection_unique_id_calculator", srcs = ["detection_unique_id_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1288,7 +1228,6 @@ cc_library( mediapipe_proto_library( name = "logic_calculator_proto", srcs = ["logic_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1298,7 +1237,6 @@ mediapipe_proto_library( cc_library( name = "logic_calculator", srcs = ["logic_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":logic_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1311,10 +1249,9 @@ cc_library( cc_library( name = "to_image_calculator", srcs = ["to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", @@ -1334,10 +1271,9 @@ cc_library( cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", @@ -1386,7 +1322,6 @@ cc_test( mediapipe_proto_library( name = "refine_landmarks_from_heatmap_calculator_proto", srcs = ["refine_landmarks_from_heatmap_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1404,7 +1339,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":refine_landmarks_from_heatmap_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1455,7 +1389,6 @@ cc_library( name = "inverse_matrix_calculator", srcs = ["inverse_matrix_calculator.cc"], hdrs = ["inverse_matrix_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc new file mode 100644 index 000000000..2dfa09028 --- /dev/null +++ 
b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <functional> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +struct BoundingBoxHash { + size_t operator()(const LocationData::BoundingBox& bbox) const { + return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^ + std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height()); + } +}; + +struct BoundingBoxEq { + bool operator()(const LocationData::BoundingBox& lhs, + const LocationData::BoundingBox& rhs) const { + return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() && + lhs.width() == rhs.width() && lhs.height() == rhs.height(); + } +}; + +} // namespace + +// This Calculator deduplicates the bounding boxes with exactly the same +// coordinates, and folds the labels into a single Detection proto. Note that +// non-maximum suppression removes the overlapping bounding boxes within a class, +// while the deduplication operation merges bounding boxes from different +// classes. + +// Example config: +// node { +// calculator: "DetectionsDeduplicateCalculator" +// input_stream: "detections" +// output_stream: "deduplicated_detections" +// } +class DetectionsDeduplicateCalculator : public Node { + public: + static constexpr Input<std::vector<Detection>> kIn{""}; + static constexpr Output<std::vector<Detection>> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Open(mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + + absl::Status Process(mediapipe::CalculatorContext* cc) { + const std::vector<Detection>& raw_detections = kIn(cc).Get(); + absl::flat_hash_map<LocationData::BoundingBox, Detection*, BoundingBoxHash, BoundingBoxEq> + bbox_to_detections; + std::vector<Detection> deduplicated_detections; + for (const auto& detection : raw_detections) { + if (!detection.has_location_data() || + !detection.location_data().has_bounding_box()) { + return absl::InvalidArgumentError( + "The location data of Detections must be BoundingBox."); + } + if (bbox_to_detections.contains( + detection.location_data().bounding_box())) { + // The bbox location already exists. Merge the detection labels into + // the existing detection proto.
+ Detection& deduplicated_detection = + *bbox_to_detections[detection.location_data().bounding_box()]; + deduplicated_detection.mutable_score()->MergeFrom(detection.score()); + deduplicated_detection.mutable_label()->MergeFrom(detection.label()); + deduplicated_detection.mutable_label_id()->MergeFrom( + detection.label_id()); + deduplicated_detection.mutable_display_name()->MergeFrom( + detection.display_name()); + } else { + // The bbox location appears for the first time. Add the detection to + // the output detection vector. + deduplicated_detections.push_back(detection); + bbox_to_detections[detection.location_data().bounding_box()] = + &deduplicated_detections.back(); + } + } + kOut(cc).Send(std::move(deduplicated_detections)); + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 73a67d322..3e566836c 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + constexpr float kMinFloat = std::numeric_limits<float>::lowest(); constexpr float kMaxFloat = std::numeric_limits<float>::max(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 6caf792a7..63de60a60 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRectTag[] = "RECT"; constexpr char kDetectionTag[] = "DETECTION"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) && diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index e27edea66..9f276da56 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -24,6 +24,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 6673816e7..7a92cfb7e 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using ::mediapipe::NormalizedRect; using mediapipe::OneEuroFilter; +using ::mediapipe::Rect; using mediapipe::RelativeVelocityFilter; void NormalizedLandmarksToLandmarks( diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index dcc6e7391..69b28af87 100644 ---
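As an aside to the new DetectionsDeduplicateCalculator above, the following standalone sketch illustrates the same merge-by-identical-box idea outside of MediaPipe. The Box and SimpleDetection structs are simplified stand-ins for the LocationData::BoundingBox and Detection protos, introduced here purely for illustration; detections are keyed by their exact pixel coordinates, and the labels and scores of later duplicates are folded into the first occurrence.

// Illustrative sketch only; plain structs instead of the MediaPipe protos.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

struct Box { int xmin, ymin, width, height; };

struct SimpleDetection {
  Box box;
  std::vector<std::string> labels;
  std::vector<float> scores;
};

struct BoxHash {
  std::size_t operator()(const Box& b) const {
    // Same idea as BoundingBoxHash above: combine the four coordinates.
    return std::hash<int>{}(b.xmin) ^ std::hash<int>{}(b.ymin) ^
           std::hash<int>{}(b.width) ^ std::hash<int>{}(b.height);
  }
};

struct BoxEq {
  bool operator()(const Box& a, const Box& b) const {
    return std::tie(a.xmin, a.ymin, a.width, a.height) ==
           std::tie(b.xmin, b.ymin, b.width, b.height);
  }
};

// Merge detections that share the exact same box, folding labels and scores.
std::vector<SimpleDetection> Deduplicate(const std::vector<SimpleDetection>& in) {
  std::vector<SimpleDetection> out;
  out.reserve(in.size());
  std::unordered_map<Box, std::size_t, BoxHash, BoxEq> index_by_box;
  for (const auto& d : in) {
    auto it = index_by_box.find(d.box);
    if (it == index_by_box.end()) {
      // First time this box is seen: keep the detection as-is.
      index_by_box[d.box] = out.size();
      out.push_back(d);
    } else {
      // Duplicate box: fold labels and scores into the kept detection.
      SimpleDetection& kept = out[it->second];
      kept.labels.insert(kept.labels.end(), d.labels.begin(), d.labels.end());
      kept.scores.insert(kept.scores.end(), d.scores.begin(), d.scores.end());
    }
  }
  return out;
}

int main() {
  std::vector<SimpleDetection> dets = {
      {{10, 10, 50, 50}, {"cat"}, {0.9f}},
      {{10, 10, 50, 50}, {"dog"}, {0.4f}},  // same box, different class
      {{80, 20, 30, 30}, {"cat"}, {0.7f}},
  };
  for (const auto& d : Deduplicate(dets)) {
    std::cout << d.box.xmin << "," << d.box.ymin << ": " << d.labels.size()
              << " label(s)\n";
  }
}

The map stores indices into the output vector rather than pointers, so vector growth cannot invalidate the stored values; this is a simplification of the calculator's pointer-based bookkeeping, not a claim about its behavior.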
a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -23,6 +23,8 @@ namespace { constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; +using ::mediapipe::NormalizedRect; + } // namespace // Projects rectangle from reference coordinate system (defined by reference diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 400be277d..bbc08255e 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + RenderAnnotation::Rectangle* NewRect( const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { auto* annotation = render_data->add_render_annotations(); diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d94615228..85ed1db72 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator to get scale for RenderData primitives. @@ -78,7 +80,9 @@ absl::Status RectToRenderScaleCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kRenderScaleTag).Set(); - + cc->SetProcessTimestampBounds( + cc->Options() + .process_timestamp_bounds()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto index dda6e2c9c..377b12412 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.proto @@ -29,4 +29,8 @@ message RectToRenderScaleCalculatorOptions { // when actual object size on the image will be `B`, than all RenderData // primitives will be scaled with factor `B/A`. optional float multiplier = 1 [default = 0.01]; + + // When true, Process is called for every new timestamp bound, with or without + // new packets. + optional bool process_timestamp_bounds = 2 [default = false]; } diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 15bb26826..4783cb919 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + // Wraps around an angle in radians to within -M_PI and M_PI. 
inline float NormalizeRadians(float angle) { return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index bcd7352a2..e843d63bf 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 53d968151..f2b8135f2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -21,19 +21,17 @@ load( licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -58,7 +56,6 @@ proto_library( proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", @@ -68,7 +65,6 @@ proto_library( proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", @@ -78,7 +74,6 @@ proto_library( proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_detector_proto", @@ -88,7 +83,6 @@ proto_library( proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", ], @@ -101,7 +95,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:motion_analysis_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_calculator_proto"], ) @@ -112,7 +105,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_calculator_proto"], ) @@ -123,7 +115,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_tracker_calculator_proto"], ) @@ -134,7 +125,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_calculator_proto"], ) @@ -145,7 +135,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_detector_cc_proto", ], - visibility = 
["//visibility:public"], deps = [":box_detector_calculator_proto"], ) @@ -155,7 +144,6 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/framework:calculator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":video_pre_stream_calculator_proto"], ) @@ -163,7 +151,6 @@ mediapipe_cc_proto_library( name = "flow_to_image_calculator_cc_proto", srcs = ["flow_to_image_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":flow_to_image_calculator_proto"], ) @@ -171,14 +158,12 @@ mediapipe_cc_proto_library( name = "opencv_video_encoder_calculator_cc_proto", srcs = ["opencv_video_encoder_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":opencv_video_encoder_calculator_proto"], ) cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_to_image_calculator_cc_proto", "//mediapipe/calculators/video/tool:flow_quantizer_model", @@ -198,7 +183,6 @@ cc_library( cc_library( name = "opencv_video_decoder_calculator", srcs = ["opencv_video_decoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", @@ -217,7 +201,6 @@ cc_library( cc_library( name = "opencv_video_encoder_calculator", srcs = ["opencv_video_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_video_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -240,7 +223,6 @@ cc_library( cc_library( name = "tvl1_optical_flow_calculator", srcs = ["tvl1_optical_flow_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -256,7 +238,6 @@ cc_library( cc_library( name = "motion_analysis_calculator", srcs = ["motion_analysis_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":motion_analysis_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -282,7 +263,6 @@ cc_library( cc_library( name = "flow_packager_calculator", srcs = ["flow_packager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -300,7 +280,6 @@ cc_library( cc_library( name = "box_tracker_calculator", srcs = ["box_tracker_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -327,7 +306,6 @@ cc_library( cc_library( name = "box_detector_calculator", srcs = ["box_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_detector_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -342,12 +320,12 @@ cc_library( "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util/tracking:box_tracker_cc_proto", + "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_detector", "//mediapipe/util/tracking:box_tracker", - "//mediapipe/util/tracking:box_tracker_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", ] + select({ "//mediapipe:android": [ @@ -369,7 +347,6 @@ cc_library( cc_library( 
name = "tracked_detection_manager_calculator", srcs = ["tracked_detection_manager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tracked_detection_manager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -390,7 +367,6 @@ cc_library( cc_library( name = "video_pre_stream_calculator", srcs = ["video_pre_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":video_pre_stream_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,7 +383,6 @@ filegroup( "testdata/format_MKV_VP8_VORBIS.video", "testdata/format_MP4_AVC720P_AAC.video", ], - visibility = ["//visibility:public"], ) cc_test( @@ -480,7 +455,6 @@ mediapipe_binary_graph( name = "parallel_tracker_binarypb", graph = "testdata/parallel_tracker_graph.pbtxt", output_name = "testdata/parallel_tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", @@ -494,7 +468,6 @@ mediapipe_binary_graph( name = "tracker_binarypb", graph = "testdata/tracker_graph.pbtxt", output_name = "testdata/tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index c416fa9b0..48664fead 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -32,6 +32,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr int kDetectionUpdateTimeOutMS = 5000; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES"; diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 41dfb8790..070cb702f 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 10e6422ba..1b733ed82 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import android.support.v7.widget.AppCompatEditText; +import androidx.appcompat.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 562f11c49..340205caa 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -18,6 +18,8 @@ licenses(["notice"]) 
package(default_visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", + "//photos/editing/mobile/mediapipe/proto:__subpackages__", ]) proto_library( @@ -30,6 +32,10 @@ proto_library( java_lite_proto_library( name = "autoflip_messages_java_proto_lite", + visibility = [ + "//java/com/google/android/apps/photos:__subpackages__", + "//javatests/com/google/android/apps/photos:__subpackages__", + ], deps = [ ":autoflip_messages_proto", ], @@ -41,6 +47,8 @@ mediapipe_cc_proto_library( cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__pkg__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", ], deps = [":autoflip_messages_proto"], ) diff --git a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto index 8507c9ad7..c89a6aea6 100644 --- a/mediapipe/examples/desktop/autoflip/autoflip_messages.proto +++ b/mediapipe/examples/desktop/autoflip/autoflip_messages.proto @@ -185,6 +185,10 @@ message ExternalRenderFrame { // original dimensions of the input video. The first step to render this // frame is to crop this rect from the input frame. optional Rect crop_from_location = 1; + // Rect that must be cropped out of the input frame. It is defined in the + // ratio of the frame of the input video. The first step to render this frame + // is to crop this rect from the input frame. + optional Rect normalized_crop_from_location = 7; // The placement location where the above rect is placed on the output frame. // This will always have the same aspect ratio as the above rect but scaling // may be required. diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index 89170dc6a..7e286b743 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -201,13 +201,26 @@ absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string, void ConstructExternalRenderMessage( const cv::Rect& crop_from_location, const cv::Rect& render_to_location, const cv::Scalar& padding_color, const uint64 timestamp_us, - ExternalRenderFrame* external_render_message) { + ExternalRenderFrame* external_render_message, int frame_width, + int frame_height) { auto crop_from_message = external_render_message->mutable_crop_from_location(); crop_from_message->set_x(crop_from_location.x); crop_from_message->set_y(crop_from_location.y); crop_from_message->set_width(crop_from_location.width); crop_from_message->set_height(crop_from_location.height); + + auto normalized_crop_from_message = + external_render_message->mutable_normalized_crop_from_location(); + normalized_crop_from_message->set_x(crop_from_location.x / + static_cast(frame_width)); + normalized_crop_from_message->set_y(crop_from_location.y / + static_cast(frame_height)); + normalized_crop_from_message->set_width(crop_from_location.width / + static_cast(frame_width)); + normalized_crop_from_message->set_height(crop_from_location.height / + static_cast(frame_height)); + auto render_to_message = external_render_message->mutable_render_to_location(); render_to_message->set_x(render_to_location.x); @@ -627,7 +640,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool 
is_end_of_scene, auto external_render_message = absl::make_unique(); ConstructExternalRenderMessage( crop_from_locations[i], render_to_locations[i], padding_colors[i], - scene_frame_timestamps_[i], external_render_message.get()); + scene_frame_timestamps_[i], external_render_message.get(), + frame_width_, frame_height_); cc->Outputs() .Tag(kExternalRenderingPerFrame) .Add(external_render_message.release(), @@ -640,7 +654,8 @@ absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene, ExternalRenderFrame render_frame; ConstructExternalRenderMessage(crop_from_locations[i], render_to_locations[i], padding_colors[i], - scene_frame_timestamps_[i], &render_frame); + scene_frame_timestamps_[i], &render_frame, + frame_width_, frame_height_); external_render_list_->push_back(render_frame); } } diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index 88728860a..c3285ea58 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -920,6 +920,41 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPathNoVideo) { EXPECT_EQ(ext_render_message.render_to_location().height(), 1124); } } + +// Checks external render message with default poly path solver using +// normalized crops. +TEST(SceneCroppingCalculatorTest, OutputsCropMessagePolyPathNormalized) { + const CalculatorGraphConfig::Node config = + ParseTextProtoOrDie( + absl::Substitute(kExternalRenderConfig, kTargetWidth, kTargetHeight)); + auto runner = absl::make_unique(config); + const int num_frames = kSceneSize; + AddScene(0, num_frames, kInputFrameWidth, kInputFrameHeight, kKeyFrameWidth, + kKeyFrameHeight, 1, runner->MutableInputs()); + + MP_EXPECT_OK(runner->Run()); + const auto& outputs = runner->Outputs(); + const auto& ext_render_per_frame = + outputs.Tag(kExternalRenderingPerFrameTag).packets; + EXPECT_EQ(ext_render_per_frame.size(), num_frames); + + for (int i = 0; i < num_frames - 1; ++i) { + const auto& ext_render_message = + ext_render_per_frame[i].Get(); + EXPECT_EQ(ext_render_message.timestamp_us(), i * 20000); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().x(), + 725 / static_cast(kInputFrameWidth)); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().y(), 0); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().width(), + 461 / static_cast(kInputFrameWidth)); + EXPECT_EQ(ext_render_message.normalized_crop_from_location().height(), + 720 / static_cast(kInputFrameHeight)); + EXPECT_EQ(ext_render_message.render_to_location().x(), 0); + EXPECT_EQ(ext_render_message.render_to_location().y(), 0); + EXPECT_EQ(ext_render_message.render_to_location().width(), 720); + EXPECT_EQ(ext_render_message.render_to_location().height(), 1124); + } +} } // namespace } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index edf98bf13..27aa088e7 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -14,12 +14,11 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) cc_binary( name = "hello_world", srcs = ["hello_world.cc"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD index 9b8f8a968..bfa770cec 100644 --- a/mediapipe/examples/ios/common/BUILD +++ b/mediapipe/examples/ios/common/BUILD @@ -29,12 +29,6 @@ objc_library( "Base.lproj/LaunchScreen.storyboard", "Base.lproj/Main.storyboard", ], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], visibility = [ "//mediapipe:__subpackages__", ], @@ -42,6 +36,10 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 50a6f68bd..7d3a75cc6 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -73,13 +73,13 @@ objc_library( "//mediapipe/modules/face_landmark:face_landmark.tflite", ], features = ["-layering_check"], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], deps = [ + "//mediapipe/framework/formats:matrix_data_cc_proto", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", + "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", @@ -87,9 +87,7 @@ objc_library( "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/graphs/face_effect:face_effect_gpu_deps", - "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 02103ce2f..6caf8c09c 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/face_mesh:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 647b7670a..c5b8e7b58 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/hand_tracking:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 056447d63..646d2e5a2 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD 
@@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 86b41ed36..4fbc2280c 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 19c51853c..e082ef2e6 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1,4 +1,3 @@ -# # Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +20,7 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +# The MediaPipe internal package group. No mediapipe users should be added to this group. package_group( name = "mediapipe_internal", packages = [ @@ -56,12 +56,12 @@ mediapipe_proto_library( srcs = ["calculator.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:mediapipe_options_proto", - "//mediapipe/framework:packet_factory_proto", - "//mediapipe/framework:packet_generator_proto", - "//mediapipe/framework:status_handler_proto", - "//mediapipe/framework:stream_handler_proto", + ":calculator_options_proto", + ":mediapipe_options_proto", + ":packet_factory_proto", + ":packet_generator_proto", + ":status_handler_proto", + ":stream_handler_proto", "@com_google_protobuf//:any_proto", ], ) @@ -78,8 +78,8 @@ mediapipe_proto_library( srcs = ["calculator_contract_test.proto"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -88,15 +88,17 @@ mediapipe_proto_library( srcs = ["calculator_profile.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) mediapipe_proto_library( name = "mediapipe_options_proto", srcs = ["mediapipe_options.proto"], - visibility = [":mediapipe_internal"], + visibility = [ + ":mediapipe_internal", + ], ) mediapipe_proto_library( @@ -125,24 +127,24 @@ mediapipe_proto_library( name = "status_handler_proto", srcs = ["status_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( name = "stream_handler_proto", srcs = ["stream_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( name = "test_calculators_proto", testonly = 1, srcs = ["test_calculators.proto"], - 
visibility = ["//visibility:public"], + visibility = [":mediapipe_internal"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -150,7 +152,7 @@ mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) # It is for pure-native Android builds where the library can't have any dependency on libandroid.so @@ -226,13 +228,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -328,10 +330,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -369,7 +371,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":graph_service", - "//mediapipe/framework:packet", + ":packet", "@com_google_absl//absl/status", ], ) @@ -379,7 +381,7 @@ cc_test( srcs = ["graph_service_manager_test.cc"], deps = [ ":graph_service_manager", - "//mediapipe/framework:packet", + ":packet", "//mediapipe/framework/port:gtest_main", ], ) @@ -391,6 +393,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -407,10 +410,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -466,6 +468,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -475,7 +478,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -583,7 +585,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -670,11 +672,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", 
":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -784,12 +786,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -875,10 +877,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -896,13 +898,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1019,10 +1021,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1035,11 +1037,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1050,6 +1051,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -1061,7 +1063,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_framework", - "//mediapipe/framework:test_calculators_cc_proto", + ":test_calculators_cc_proto", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", @@ -1098,7 +1100,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1163,22 +1165,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", 
- "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1203,11 +1205,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1234,6 +1236,7 @@ cc_test( linkstatic = 1, deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1243,7 +1246,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1257,11 +1259,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1369,6 +1371,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1377,7 +1380,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1404,6 +1406,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1411,13 +1414,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1469,6 +1471,7 @@ cc_test( "//mediapipe/framework/stream_handler:mux_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/strings", ], ) @@ -1481,12 +1484,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", 
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1630,8 +1633,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", @@ -1659,9 +1662,6 @@ cc_test( "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:default_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 5af9ee5e0..da09acc83 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -176,22 +176,50 @@ class SourceImpl { : SourceImpl(&GetWithAutoGrow(vec, 0)) {} explicit SourceImpl(SourceBase* base) : base_(base) {} + // Connects MediaPipe stream or side packet to a destination: + // - node input (input stream) / side input (input side packet) + // - graph output (output stream) / side output (output side packet). + // + // MediaPipe streams and side packets can be connected to multiple + // destinations. Side packets and packets added to streams are sent to all + // connected destinations. template {}, int>::type = 0> - Src& AddTarget(const Dst& dest) { + Src& ConnectTo(const Dst& dest) { CHECK(dest.base_.source == nullptr); dest.base_.source = base_; base_->dests_.emplace_back(&dest.base_); return *this; } + + // Shortcut for `ConnectTo`. + // + // Connects MediaPipe stream or side packet to a destination: + // - node input (input stream) / side input (input side packet) + // - graph output (output stream) / side output (output side packet). + // + // MediaPipe streams and side packets can be connected to multiple + // destinations. Side packets and packets added to streams are sent to all + // connected destinations. + template + Src& operator>>(const Dst& dest) { + return ConnectTo(dest); + } + + template + bool operator==(const SourceImpl& other) { + return base_ == other.base_; + } + + template + bool operator!=(const SourceImpl& other) { + return !(*this == other); + } + Src& SetName(std::string name) { base_->name_ = std::move(name); return *this; } - template - Src& operator>>(const Dst& dest) { - return AddTarget(dest); - } template {}, int> = 0> @@ -200,6 +228,9 @@ class SourceImpl { } private: + template + friend class SourceImpl; + // Never null. SourceBase* base_; }; @@ -380,7 +411,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> @@ -394,11 +425,11 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(Calc::kCalculatorName) {} + Node() : NodeBase(std::string(Calc::kCalculatorName)) {} // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. 
// TODO: only use this for pure interfaces - Node(const std::string& type_override) : NodeBase(type_override) {} + Node(std::string type_override) : NodeBase(std::move(type_override)) {} // These methods only allow access to ports declared in the contract. // The argument must be a tag object created with the MPP_TAG macro. diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3bf3ec198..363971689 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -15,18 +15,32 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -namespace mediapipe { -namespace api2 { -namespace test { +namespace mediapipe::api2::builder { +namespace { + +using ::mediapipe::api2::test::Bar; +using ::mediapipe::api2::test::FloatAdder; +using ::mediapipe::api2::test::Foo; +using ::mediapipe::api2::test::Foo2; +using ::mediapipe::api2::test::FooBar1; TEST(BuilderTest, BuildGraph) { - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream base = graph.In("IN").SetName("base"); + SidePacket side = graph.SideIn("SIDE").SetName("side"); + auto& foo = graph.AddNode("Foo"); + base >> foo.In("BASE"); + side >> foo.SideIn("SIDE"); + Stream foo_out = foo.Out("OUT"); + auto& bar = graph.AddNode("Bar"); - graph.In("IN").SetName("base") >> foo.In("BASE"); - graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE"); - foo.Out("OUT") >> bar.In("IN"); - bar.Out("OUT").SetName("out") >> graph.Out("OUT"); + foo_out >> bar.In("IN"); + Stream bar_out = bar.Out("OUT"); + + // Graph outputs. + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -48,23 +62,20 @@ TEST(BuilderTest, BuildGraph) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, CopyableSource) { - builder::Graph graph; - builder::Source a = graph[Input("A")]; - a.SetName("a"); - builder::Source b = graph[Input("B")]; - b.SetName("b"); - builder::SideSource side_a = graph[SideInput("SIDE_A")]; - side_a.SetName("side_a"); - builder::SideSource side_b = graph[SideInput("SIDE_B")]; - side_b.SetName("side_b"); - builder::Destination out = graph[Output("OUT")]; - builder::SideDestination side_out = - graph[SideOutput("SIDE_OUT")]; +TEST(BuilderTest, CopyableStream) { + Graph graph; + Stream a = graph.In("A").SetName("a").Cast(); + Stream b = graph.In("B").SetName("b").Cast(); + SidePacket side_a = + graph.SideIn("SIDE_A").SetName("side_a").Cast(); + SidePacket side_b = + graph.SideIn("SIDE_B").SetName("side_b").Cast(); + Destination out = graph.Out("OUT").Cast(); + SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); - builder::Source input = a; + Stream input = a; input = b; - builder::SideSource side_input = side_b; + SidePacket side_input = side_b; side_input = side_a; input >> out; @@ -83,31 +94,29 @@ TEST(BuilderTest, CopyableSource) { } TEST(BuilderTest, BuildGraphWithFunctions) { - builder::Graph graph; + Graph graph; - builder::Source base = graph[Input("IN")]; - base.SetName("base"); - builder::SideSource side = graph[SideInput("SIDE")]; - side.SetName("side"); + // Graph inputs. 
+ Stream base = graph.In("IN").SetName("base").Cast(); + SidePacket side = graph.SideIn("SIDE").SetName("side").Cast(); - auto foo_fn = [](builder::Source base, builder::SideSource side, - builder::Graph& graph) { + auto foo_fn = [](Stream base, SidePacket side, Graph& graph) { auto& foo = graph.AddNode("Foo"); - base >> foo[Input("BASE")]; - side >> foo[SideInput("SIDE")]; - return foo[Output("OUT")]; + base >> foo.In("BASE"); + side >> foo.SideIn("SIDE"); + return foo.Out("OUT")[0].Cast(); }; - builder::Source foo_out = foo_fn(base, side, graph); + Stream foo_out = foo_fn(base, side, graph); - auto bar_fn = [](builder::Source in, builder::Graph& graph) { + auto bar_fn = [](Stream in, Graph& graph) { auto& bar = graph.AddNode("Bar"); - in >> bar[Input("IN")]; - return bar[Output("OUT")]; + in >> bar.In("IN"); + return bar.Out("OUT")[0].Cast(); }; - builder::Source bar_out = bar_fn(foo_out, graph); - bar_out.SetName("out"); + Stream bar_out = bar_fn(foo_out, graph); - bar_out >> graph[Output("OUT")]; + // Graph outputs. + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -131,13 +140,22 @@ TEST(BuilderTest, BuildGraphWithFunctions) { template void BuildGraphTypedTest() { - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream base = graph.In("IN").SetName("base"); + SidePacket side = graph.SideIn("SIDE").SetName("side"); + auto& foo = graph.AddNode(); + base >> foo.In(MPP_TAG("BASE")); + side >> foo.SideIn(MPP_TAG("BIAS")); + Stream foo_out = foo.Out(MPP_TAG("OUT")); + auto& bar = graph.AddNode(); - graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); - graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS")); - foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN")); - bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT"); + foo_out >> bar.In(MPP_TAG("IN")); + Stream bar_out = bar.Out(MPP_TAG("OUT")); + + // Graph outputs. + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie( @@ -161,18 +179,26 @@ void BuildGraphTypedTest() { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } -TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } TEST(BuilderTest, FanOut) { - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream base = graph.In("IN").SetName("base"); + auto& foo = graph.AddNode("Foo"); + base >> foo.In("BASE"); + Stream foo_out = foo.Out("OUT"); + auto& adder = graph.AddNode("FloatAdder"); - graph.In("IN").SetName("base") >> foo.In("BASE"); - foo.Out("OUT") >> adder.In("IN")[0]; - foo.Out("OUT") >> adder.In("IN")[1]; - adder.Out("OUT").SetName("out") >> graph.Out("OUT"); + foo_out >> adder.In("IN")[0]; + foo_out >> adder.In("IN")[1]; + Stream out = adder.Out("OUT"); + + // Graph outputs. + out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -194,13 +220,21 @@ TEST(BuilderTest, FanOut) { } TEST(BuilderTest, TypedMultiple) { - builder::Graph graph; + Graph graph; + // Graph inputs. 
+ Stream base = graph.In("IN").SetName("base"); + auto& foo = graph.AddNode(); + base >> foo.In(MPP_TAG("BASE")); + Stream foo_out = foo.Out(MPP_TAG("OUT")); + auto& adder = graph.AddNode(); - graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); - foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0]; - foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1]; - adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT"); + foo_out >> adder.In(MPP_TAG("IN"))[0]; + foo_out >> adder.In(MPP_TAG("IN"))[1]; + Stream out = adder.Out(MPP_TAG("OUT")); + + // Graph outputs. + out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -222,14 +256,21 @@ TEST(BuilderTest, TypedMultiple) { } TEST(BuilderTest, TypedByPorts) { - builder::Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); + Graph graph; + // Graph inputs. + Stream base = graph.In(FooBar1::kIn).SetName("base"); - graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; - foo[Foo::kOut] >> adder[FloatAdder::kIn][0]; - foo[Foo::kOut] >> adder[FloatAdder::kIn][1]; - adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut]; + auto& foo = graph.AddNode(); + base >> foo[Foo::kBase]; + Stream foo_out = foo[Foo::kOut]; + + auto& adder = graph.AddNode(); + foo_out >> adder[FloatAdder::kIn][0]; + foo_out >> adder[FloatAdder::kIn][1]; + Stream out = adder[FloatAdder::kOut]; + + // Graph outputs. + out.SetName("out") >> graph.Out(FooBar1::kOut); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -251,10 +292,16 @@ TEST(BuilderTest, TypedByPorts) { } TEST(BuilderTest, PacketGenerator) { - builder::Graph graph; + Graph graph; + // Graph inputs. + SidePacket side_in = graph.SideIn("IN"); + auto& generator = graph.AddPacketGenerator("FloatGenerator"); - graph.SideIn("IN") >> generator.SideIn("IN"); - generator.SideOut("OUT") >> graph.SideOut("OUT"); + side_in >> generator.SideIn("IN"); + SidePacket side_out = generator.SideOut("OUT"); + + // Graph outputs. + side_out >> graph.SideOut("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -270,13 +317,22 @@ TEST(BuilderTest, PacketGenerator) { } TEST(BuilderTest, EmptyTag) { - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream a = graph.In("A").SetName("a"); + Stream c = graph.In("C").SetName("c"); + Stream b = graph.In("B").SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In("A").SetName("a") >> foo.In("")[0]; - graph.In("C").SetName("c") >> foo.In("")[2]; - graph.In("B").SetName("b") >> foo.In("")[1]; - foo.Out("")[0].SetName("x") >> graph.Out("ONE"); - foo.Out("")[1].SetName("y") >> graph.Out("TWO"); + a >> foo.In("")[0]; + c >> foo.In("")[2]; + b >> foo.In("")[1]; + Stream x = foo.Out("")[0]; + Stream y = foo.Out("")[1]; + + // Graph outputs. + x.SetName("x") >> graph.Out("ONE"); + y.SetName("y") >> graph.Out("TWO"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -302,11 +358,18 @@ TEST(BuilderTest, StringLikeTags) { const std::string kB = "B"; constexpr absl::string_view kC = "C"; - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream a = graph.In(kA).SetName("a"); + Stream b = graph.In(kB).SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In(kA).SetName("a") >> foo.In(kA); - graph.In(kB).SetName("b") >> foo.In(kB); - foo.Out(kC).SetName("c") >> graph.Out(kC); + a >> foo.In(kA); + b >> foo.In(kB); + Stream c = foo.Out(kC); + + // Graph outputs. 
+ c.SetName("c") >> graph.Out(kC); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -324,13 +387,22 @@ TEST(BuilderTest, StringLikeTags) { } TEST(BuilderTest, GraphIndexes) { - builder::Graph graph; + Graph graph; + // Graph inputs. + Stream a = graph.In(0).SetName("a"); + Stream c = graph.In(1).SetName("c"); + Stream b = graph.In(2).SetName("b"); + auto& foo = graph.AddNode("Foo"); - graph.In(0).SetName("a") >> foo.In("")[0]; - graph.In(1).SetName("c") >> foo.In("")[2]; - graph.In(2).SetName("b") >> foo.In("")[1]; - foo.Out("")[0].SetName("x") >> graph.Out(1); - foo.Out("")[1].SetName("y") >> graph.Out(0); + a >> foo.In("")[0]; + c >> foo.In("")[2]; + b >> foo.In("")[1]; + Stream x = foo.Out("")[0]; + Stream y = foo.Out("")[1]; + + // Graph outputs. + x.SetName("x") >> graph.Out(1); + y.SetName("y") >> graph.Out(0); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -376,29 +448,27 @@ class AnyAndSameTypeCalculator : public NodeIntf { }; TEST(BuilderTest, AnyAndSameTypeHandledProperly) { - builder::Graph graph; - builder::Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - builder::Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Graph graph; + Stream any_input = graph.In("GRAPH_ANY_INPUT"); + Stream int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - - builder::Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; - any_type_output.SetName("any_type_output"); - - builder::Source same_type_output = + Stream same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; - same_type_output.SetName("same_type_output"); - builder::Source recursive_same_type_output = + Stream recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; - recursive_same_type_output.SetName("recursive_same_type_output"); - builder::Source same_int_output = - node[AnyAndSameTypeCalculator::kSameIntOutput]; - same_int_output.SetName("same_int_output"); - builder::Source recursive_same_int_type_output = + Stream same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; + Stream recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; + + any_type_output.SetName("any_type_output"); + same_type_output.SetName("same_type_output"); + recursive_same_type_output.SetName("recursive_same_type_output"); + same_int_output.SetName("same_int_output"); recursive_same_int_type_output.SetName("recursive_same_int_type_output"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie< @@ -420,17 +490,17 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { } TEST(BuilderTest, AnyTypeCanBeCast) { - builder::Graph graph; - builder::Source any_input = + Graph graph; + Stream any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - builder::Source any_type_output = + Stream any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); - any_type_output.SetName("any_type_output"); - any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + any_type_output.SetName("any_type_output") >> + graph.Out("GRAPH_ANY_OUTPUT").Cast(); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -446,11 +516,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) { } TEST(BuilderTest, 
MultiPortIsCastToMultiPort) { - builder::Graph graph; - builder::MultiSource any_input = graph.In("ANY_INPUT"); - builder::MultiSource int_input = any_input.Cast(); - builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); - builder::MultiDestination int_output = any_output.Cast(); + Graph graph; + MultiSource any_input = graph.In("ANY_INPUT"); + MultiSource int_input = any_input.Cast(); + MultiDestination any_output = graph.Out("ANY_OUTPUT"); + MultiDestination int_output = any_output.Cast(); int_input >> int_output; CalculatorGraphConfig expected = @@ -462,11 +532,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { } TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { - builder::Graph graph; - builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); - builder::Source any_input = any_multi_input; - builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); - builder::Destination any_output = any_multi_output; + Graph graph; + MultiSource any_multi_input = graph.In("ANY_INPUT"); + Stream any_input = any_multi_input; + MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + Destination any_output = any_multi_output; any_input >> any_output; CalculatorGraphConfig expected = @@ -478,11 +548,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { } TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { - builder::Graph graph; - builder::Source int_input = graph.In("INT_INPUT").Cast(); - builder::Source any_input = graph.In("ANY_OUTPUT"); - builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); - builder::Destination any_output = graph.Out("ANY_OUTPUT"); + Graph graph; + Stream int_input = graph.In("INT_INPUT").Cast(); + Stream any_input = graph.In("ANY_OUTPUT"); + Destination int_output = graph.Out("INT_OUTPUT").Cast(); + Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; any_input >> any_output; @@ -496,6 +566,51 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -} // namespace test -} // namespace api2 -} // namespace mediapipe +TEST(BuilderTest, TestStreamEqualsNotEqualsOperators) { + Graph graph; + Stream input0 = graph.In(0); + EXPECT_TRUE(input0 == input0); + EXPECT_FALSE(input0 != input0); + + EXPECT_TRUE(input0 == input0.Cast()); + EXPECT_FALSE(input0.Cast() != input0); + + EXPECT_TRUE(input0.Cast() == input0.Cast()); + EXPECT_FALSE(input0.Cast() != input0.Cast()); + + Stream input1 = graph.In(1); + EXPECT_FALSE(input0 == input1); + EXPECT_TRUE(input0 != input1); + + input1 = input0; + EXPECT_TRUE(input0 == input1); + EXPECT_FALSE(input0 != input1); + EXPECT_TRUE(input0.Cast() == input1.Cast()); + EXPECT_FALSE(input0.Cast() != input1.Cast()); +} + +TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) { + Graph graph; + SidePacket side_input0 = graph.SideIn(0); + EXPECT_TRUE(side_input0 == side_input0); + EXPECT_FALSE(side_input0 != side_input0); + + EXPECT_TRUE(side_input0 == side_input0.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input0); + + EXPECT_TRUE(side_input0.Cast() == side_input0.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input0.Cast()); + + SidePacket side_input1 = graph.SideIn(1); + EXPECT_FALSE(side_input0 == side_input1); + EXPECT_TRUE(side_input0 != side_input1); + + side_input1 = side_input0; + EXPECT_TRUE(side_input0 == side_input1); + EXPECT_FALSE(side_input0 != side_input1); + EXPECT_TRUE(side_input0.Cast() == side_input1.Cast()); + EXPECT_FALSE(side_input0.Cast() != side_input1.Cast()); +} + +} // namespace 
+} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Packet()->Packet; +explicit Packet() -> Packet; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index e63d3651e..eee542640 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -557,8 +557,8 @@ class OutputSidePacketAccess { if (output_) output_->Set(ToOldPacket(std::move(packet))); } - void Set(const T& payload) { Set(MakePacket(payload)); } - void Set(T&& payload) { Set(MakePacket(std::move(payload))); } + void Set(const T& payload) { Set(api2::MakePacket(payload)); } + void Set(T&& payload) { Set(api2::MakePacket(std::move(payload))); } private: OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto index 7c5e8b144..eecd033c9 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -382,7 +382,7 @@ message CalculatorGraphConfig { // is empty and no other nodes are running (to prevent possible deadlocks due // to a incorrectly specified value). This global parameter is set to 100 // packets by default to enable pipelining. If any node indicates that it - // buffers packets before emitting them, then the max(node_buffer_size, + // buffers packets before emitting them, then the max(buffer_size_hint, // max_queue_size) is used. Set this parameter to -1 to disable throttling // (i.e. the graph will use as much memory as it requires). If not specified, // the limit is 100 packets. diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc index e7612501a..be9103b4d 100644 --- a/mediapipe/framework/calculator_context_test.cc +++ b/mediapipe/framework/calculator_context_test.cc @@ -131,10 +131,10 @@ TEST(CalculatorTest, GetOptions) { auto calculator_state_3 = MakeCalculatorState(config.node(3), 3); auto cc_3 = MakeCalculatorContext(&*calculator_state_3); - // Get a proto2 options extension from Node::options. + // Get a google::protobuf options extension from Node::options. EXPECT_EQ(cc_0->Options().jitter(), 0.123); - // Get a proto2 options extension from Node::node_options. + // Get a google::protobuf options extension from Node::node_options. EXPECT_EQ(cc_1->Options().jitter(), 0.123); // Get a proto3 options protobuf::Any from Node::node_options. diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index c17a2e1e2..526a74835 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { manager_->LockIntroData(); } +void CalculatorGraph::GraphInputStream::SetNextTimestampBound( + Timestamp timestamp) { + shard_.SetNextTimestampBound(timestamp); +} + void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { - // Since GraphInputStream doesn't allow SetOffset() and - // SetNextTimestampBound(), the timestamp bound to propagate is only - // determined by the timestamp of the output packets. 
- CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name() - << "\" failed"; - manager_->PropagateUpdatesToMirrors( - shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_); + manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_); } void CalculatorGraph::GraphInputStream::Close() { @@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream( return AddPacketToInputStreamInternal(stream_name, std::move(packet)); } +absl::Status CalculatorGraph::SetInputStreamTimestampBound( + const std::string& stream_name, Timestamp timestamp) { + std::unique_ptr* stream = + mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "SetInputStreamTimestampBound called on input stream \"$0\" which is not " + "a graph input stream.", + stream_name); + (*stream)->SetNextTimestampBound(timestamp); + (*stream)->PropagateUpdatesToMirrors(); + return absl::OkStatus(); +} + // We avoid having two copies of this code for AddPacketToInputStream( // const Packet&) and AddPacketToInputStream(Packet &&) by having this // internal-only templated version. T&& is a forwarding reference here, so diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index c51476102..04f9de45f 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -257,6 +257,10 @@ class CalculatorGraph { absl::Status AddPacketToInputStream(const std::string& stream_name, Packet&& packet); + // Indicates that input will arrive no earlier than a certain timestamp. + absl::Status SetInputStreamTimestampBound(const std::string& stream_name, + Timestamp timestamp); + // Sets the queue size of a graph input stream, overriding the graph default. absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name, int max_queue_size); @@ -425,6 +429,8 @@ class CalculatorGraph { void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); } + void SetNextTimestampBound(Timestamp timestamp); + void PropagateUpdatesToMirrors(); void Close(); diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index b55f9459d..d149337cc 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/strings/str_replace.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,6 +26,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { @@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(EmptyPacketCalculator); -// This test shows that an output timestamp bound can be specified by outputing +// This test shows that an output timestamp bound can be specified by outputting // an empty packet with a settled timestamp. TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { // OffsetAndBoundCalculator runs on parallel threads and sends ts @@ -1580,6 +1583,195 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); } + // Shut down the graph. 
+ MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows that input timestamp bounds can be specified using +// CalculatorGraph::SetInputStreamTimestampBound. +TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) { + std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "input_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send in timestamp bounds. + for (int i = 0; i < 9; ++i) { + const int ts = 10 + i * 10; + MP_ASSERT_OK(graph.SetInputStreamTimestampBound( + "input_0", Timestamp(ts).NextAllowedInStream())); + MP_ASSERT_OK(graph.WaitUntilIdle()); + } + + // 9 timestamp bounds are converted to packets. + EXPECT_EQ(output_0_packets.size(), 9); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); + } + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows how an input stream with infrequent packets, such as +// configuration protobufs, can be consumed while processing more frequent +// packets, such as video frames. +TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) { + // PassThroughCalculator consuming two input streams, with default ISH. + std::string config_str = R"pb( + input_stream: "INFREQUENT:config" + input_stream: "FREQUENT:frame" + node { + calculator: "PassThroughCalculator" + input_stream: "CONFIG:config" + input_stream: "VIDEO:frame" + output_stream: "VIDEO:output_frame" + output_stream: "CONFIG:output_config" + } + )pb"; + + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector frame_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_frame", + [&](const Packet& p) { + frame_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + std::vector config_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_config", + [&](const Packet& p) { + config_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Utility functions to send packets or timestamp bounds. + auto send_fn = [&](std::string stream, std::string value, int ts) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + stream, + MakePacket(absl::StrCat(value)).At(Timestamp(ts)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + auto bound_fn = [&](std::string stream, int ts) { + MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + // Send in a frame packet. + send_fn("frame", "frame_0", 0); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, ElementsAreArray(PacketMatchers({}))); + bound_fn("config", 10000); + // The frame is processed after a fresh config timestamp bound arrives. 
+ EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_1", 20000); + // The frame is not processed yet. + // The PassThroughCalculator with TimestampOffset 0 now propagates + // Timestamp bound 10000 to both "output_frame" and "output_config", + // which appears here as Packet().At(Timestamp(9999)). The timestamp + // bounds at 29999 and 50000 are propagated similarly. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + }))); + bound_fn("config", 30000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_2", 40000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + }))); + send_fn("config", "config_1", 50000); + // The frame is processed after a fresh config arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_3", 60000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + }))); + bound_fn("config", 70000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + MakePacket("frame_3").At(Timestamp(60000)), + }))); + + // One config packet is delivered. + EXPECT_THAT(config_packets, + ElementsAreArray(PacketMatchers({ + Packet().At(Timestamp(0)), + Packet().At(Timestamp(9999)), + Packet().At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + Packet().At(Timestamp(40000)), + MakePacket("config_1").At(Timestamp(50000)), + Packet().At(Timestamp(60000)), + }))); + // Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index a39d7476e..7994aae75 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,7 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package_group( + name = "mediapipe_internal", + packages = [ + "//mediapipe/...", + ], +) + +package(default_visibility = ["mediapipe_internal"]) bzl_library( name = "expand_template_bzl", @@ -50,13 +57,11 @@ mediapipe_proto_library( cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cleanup", hdrs = ["cleanup.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -83,10 +88,10 @@ cc_library( name = "message_matchers", testonly = True, hdrs = ["message_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = [ "//mediapipe/framework/port:__pkg__", - "//third_party/visionai/algorithms/tracking:__pkg__", ], deps = [ "//mediapipe/framework/port:core_proto", @@ -108,7 +113,6 @@ cc_library( name = "file_helpers", srcs = ["file_helpers.cc"], hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":file_path", "//mediapipe/framework/port:status", @@ -134,7 +138,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -143,6 +146,7 @@ cc_library( cc_library( name = "map_util", hdrs = ["map_util.h"], + # Use this library through "mediapipe/framework/port:map_util". visibility = ["//mediapipe/framework/port:__pkg__"], deps = ["//mediapipe/framework/port:logging"], @@ -151,7 +155,9 @@ cc_library( cc_library( name = "mathutil", hdrs = ["mathutil.h"], - visibility = ["//visibility:public"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [ "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -171,12 +177,12 @@ cc_library( cc_library( name = "no_destructor", hdrs = ["no_destructor.h"], - visibility = ["//visibility:public"], ) cc_library( name = "point", hdrs = ["point2.h"], + # Use this library through "mediapipe/framework/port:point". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -190,13 +196,13 @@ cc_library( cc_library( name = "random", hdrs = ["random_base.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/port:integral_types"], ) cc_library( name = "rectangle", hdrs = ["rectangle.h"], + # Use this library through "mediapipe/framework/port:rectangle". 
visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -211,20 +217,22 @@ cc_library( name = "registration_token", srcs = ["registration_token.cc"], hdrs = ["registration_token.h"], - visibility = ["//visibility:public"], ) cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], - visibility = ["//visibility:public"], + visibility = [ + "mediapipe_internal", + ], deps = [ ":registration_token", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", @@ -235,6 +243,7 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], + # Use this library through "mediapipe/framework/port:singleton". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -245,6 +254,7 @@ cc_library( cc_library( name = "source_location", hdrs = ["source_location.h"], + # Use this library through "mediapipe/framework/port:source_location". visibility = ["//mediapipe/framework/port:__pkg__"], ) @@ -261,6 +271,7 @@ cc_library( "status_builder.h", "status_macros.h", ], + # Use this library through "mediapipe/framework/port:status". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -278,13 +289,13 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], ) cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], + # Use this library through "mediapipe/framework/port:gtest_main". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -298,6 +309,7 @@ cc_library( name = "ret_check", srcs = ["ret_check.cc"], hdrs = ["ret_check.h"], + # Use this library through "mediapipe/framework/port:ret_check". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -309,7 +321,6 @@ cc_library( cc_library( name = "thread_options", hdrs = ["thread_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -319,6 +330,7 @@ cc_library( "//conditions:default": ["threadpool_pthread_impl.cc"], }), hdrs = ["threadpool.h"], + # Use this library through "mediapipe/framework/port:threadpool". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -333,6 +345,7 @@ cc_library( name = "topologicalsorter", srcs = ["topologicalsorter.cc"], hdrs = ["topologicalsorter.h"], + # Use this library through "mediapipe/framework/port:topologicalsorter". visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -343,6 +356,7 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], + # Use this library through "mediapipe/framework/port:vector". 
visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ @@ -355,7 +369,6 @@ cc_library( cc_test( name = "mathutil_unittest", srcs = ["mathutil_unittest.cc"], - visibility = ["//visibility:public"], deps = [ ":mathutil", "//mediapipe/framework/port:benchmark", @@ -367,7 +380,6 @@ cc_test( name = "registration_token_test", srcs = ["registration_token_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:gtest_main", @@ -380,7 +392,6 @@ cc_test( timeout = "long", srcs = ["safe_int_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":intops", "//mediapipe/framework/port:gtest_main", @@ -392,7 +403,6 @@ cc_test( name = "monotonic_clock_test", srcs = ["monotonic_clock_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":clock", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index b39a1e293..9d80aafea 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -26,10 +26,12 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - RegistrationToken Register(const std::string& name, Function func) + RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); @@ -189,14 +191,15 @@ class FunctionRegistry { absl::enable_if_t, std::tuple>::value, int> = 0> - ReturnType Invoke(const std::string& name, Args2&&... args) + ReturnType Invoke(absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { Function function; { absl::ReaderMutexLock lock(&lock_); auto it = functions_.find(name); if (it == functions_.end()) { - return absl::NotFoundError("No registered object with name: " + name); + return absl::NotFoundError( + absl::StrCat("No registered object with name: ", name)); } function = it->second; } @@ -206,7 +209,7 @@ class FunctionRegistry { // Invokes the specified factory function and returns the result. // Namespaces in |name| and |ns| are separated by kNameSep. template - ReturnType Invoke(const std::string& ns, const std::string& name, + ReturnType Invoke(absl::string_view ns, absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { return Invoke(GetQualifiedName(ns, name), args...); } @@ -214,14 +217,14 @@ class FunctionRegistry { // Note that it's possible for registered implementations to be subsequently // unregistered, though this will never happen with registrations made via // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. - bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { + bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { absl::ReaderMutexLock lock(&lock_); return functions_.count(name) != 0; } // Returns true if the specified factory function is available. 
// Namespaces in |name| and |ns| are separated by kNameSep. - bool IsRegistered(const std::string& ns, const std::string& name) const + bool IsRegistered(absl::string_view ns, absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { return IsRegistered(GetQualifiedName(ns, name)); } @@ -244,13 +247,13 @@ class FunctionRegistry { // Normalizes a C++ qualified name. Validates the name qualification. // The name must be either unqualified or fully qualified with a leading "::". // The leading "::" in a fully qualified name is stripped. - std::string GetNormalizedName(const std::string& name) { + std::string GetNormalizedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector names = absl::StrSplit(name, kCxxSep); if (names[0].empty()) { names.erase(names.begin()); } else { - CHECK_EQ(1, names.size()) + CHECK_EQ(1u, names.size()) << "A registered class name must be either fully qualified " << "with a leading :: or unqualified, got: " << name << "."; } @@ -259,8 +262,8 @@ class FunctionRegistry { // Returns the registry key for a name specified within a namespace. // Namespaces are separated by kNameSep. - std::string GetQualifiedName(const std::string& ns, - const std::string& name) const { + std::string GetQualifiedName(absl::string_view ns, + absl::string_view name) const { using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kNameSep; std::vector names = absl::StrSplit(name, kNameSep); @@ -287,10 +290,10 @@ class FunctionRegistry { private: mutable absl::Mutex lock_; - std::unordered_map functions_ ABSL_GUARDED_BY(lock_); + absl::flat_hash_map functions_ ABSL_GUARDED_BY(lock_); // For names included in NamespaceAllowlist, strips the namespace. - std::string GetAdjustedName(const std::string& name) { + std::string GetAdjustedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector names = absl::StrSplit(name, kCxxSep); std::string base_name = names.back(); @@ -299,10 +302,10 @@ class FunctionRegistry { if (NamespaceAllowlist::TopNamespaces().count(ns)) { return base_name; } - return name; + return std::string(name); } - void Unregister(const std::string& name) { + void Unregister(absl::string_view name) { absl::WriterMutexLock lock(&lock_); std::string adjusted_name = GetAdjustedName(name); if (adjusted_name != name) { @@ -317,7 +320,7 @@ class GlobalFactoryRegistry { using Functions = FunctionRegistry; public: - static RegistrationToken Register(const std::string& name, + static RegistrationToken Register(absl::string_view name, typename Functions::Function func) { return functions()->Register(name, std::move(func)); } @@ -326,7 +329,7 @@ class GlobalFactoryRegistry { // If using namespaces with this registry, the variant with a namespace // argument should be used. template - static typename Functions::ReturnType CreateByName(const std::string& name, + static typename Functions::ReturnType CreateByName(absl::string_view name, Args2&&... args) { return functions()->Invoke(name, std::forward(args)...); } @@ -334,7 +337,7 @@ class GlobalFactoryRegistry { // Returns true if the specified factory function is available. // If using namespaces with this registry, the variant with a namespace // argument should be used. 
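  // Illustrative usage sketch (the WidgetRegistry alias and Widget factory are
  // hypothetical, not part of this change): name lookups in registries built
  // on GlobalFactoryRegistry now take absl::string_view, so callers can pass
  // string literals or views without materializing a std::string.
  //
  //   using WidgetRegistry = mediapipe::GlobalFactoryRegistry<
  //       absl::StatusOr<std::unique_ptr<Widget>>>;
  //
  //   static mediapipe::RegistrationToken token = WidgetRegistry::Register(
  //       "::example::Widget", [] { return std::make_unique<Widget>(); });
  //
  //   absl::string_view name = "::example::Widget";
  //   if (WidgetRegistry::IsRegistered(name)) {
  //     auto widget_or = WidgetRegistry::CreateByName(name);
  //   }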
- static bool IsRegistered(const std::string& name) { + static bool IsRegistered(absl::string_view name) { return functions()->IsRegistered(name); } @@ -350,13 +353,13 @@ class GlobalFactoryRegistry { std::tuple>::value, int> = 0> static typename Functions::ReturnType CreateByNameInNamespace( - const std::string& ns, const std::string& name, Args2&&... args) { + absl::string_view ns, absl::string_view name, Args2&&... args) { return functions()->Invoke(ns, name, std::forward(args)...); } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - static bool IsRegistered(const std::string& ns, const std::string& name) { + static bool IsRegistered(absl::string_view ns, absl::string_view name) { return functions()->IsRegistered(ns, name); } diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 70775949d..0202b8689 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -97,39 +97,24 @@ absl::Status StatusBuilder::Impl::JoinMessageToStatus() { }()); } -StatusBuilder::Impl::Impl(const absl::Status& status, const char* file, - int line) - : status(status), line(line), file(file), stream() {} - -StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line) - : status(std::move(status)), line(line), file(file), stream() {} - StatusBuilder::Impl::Impl(const absl::Status& status, mediapipe::source_location location) - : status(status), - line(location.line()), - file(location.file_name()), - stream() {} + : status(status), location(location), stream() {} StatusBuilder::Impl::Impl(absl::Status&& status, mediapipe::source_location location) - : status(std::move(status)), - line(location.line()), - file(location.file_name()), - stream() {} + : status(std::move(status)), location(location), stream() {} StatusBuilder::Impl::Impl(const Impl& other) : status(other.status), - line(other.line), - file(other.file), + location(other.location), no_logging(other.no_logging), stream(other.stream.str()), join_style(other.join_style) {} StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) { status = other.status; - line = other.line; - file = other.file; + location = other.location; no_logging = other.no_logging; stream = std::ostringstream(other.stream.str()); join_style = other.join_style; diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index d2e40d575..ae11699d2 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -60,17 +60,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { ? nullptr : std::make_unique(absl::Status(code, ""), location)) {} - StatusBuilder(const absl::Status& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique(original_status, file, line)) {} - - StatusBuilder(absl::Status&& original_status, const char* file, int line) - : impl_(original_status.ok() - ? 
nullptr - : std::make_unique(std::move(original_status), file, - line)) {} - bool ok() const { return !impl_; } StatusBuilder& SetAppend() &; @@ -109,8 +98,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { kPrepend, }; - Impl(const absl::Status& status, const char* file, int line); - Impl(absl::Status&& status, const char* file, int line); Impl(const absl::Status& status, mediapipe::source_location location); Impl(absl::Status&& status, mediapipe::source_location location); Impl(const Impl&); @@ -120,10 +107,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // The status that the result will be based on. absl::Status status; - // The line to record if this file is logged. - int line; - // Not-owned: The file to record if this status is logged. - const char* file; + // The source location to record if this file is logged. + mediapipe::source_location location; // Logging disabled if true. bool no_logging = false; // The additional messages added with `<<`. This is nullptr when status_ is diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index 560acd3c6..f517bb909 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -33,21 +33,6 @@ TEST(StatusBuilder, OkStatusRvalue) { ASSERT_EQ(status, absl::OkStatus()); } -TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - -TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) { - const auto original_status = absl::OkStatus(); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -60,30 +45,6 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } -TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, - "original message"), - "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - -TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) { - const auto original_status = - absl::Status(absl::StatusCode::kNotFound, "original message"); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - TEST(StatusBuilder, PrependModeLvalue) { StatusBuilder builder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index 757d99392..92bbf0b84 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -81,11 +81,11 @@ // MP_RETURN_IF_ERROR(foo.Method(args...)); // return absl::OkStatus(); // } -#define MP_RETURN_IF_ERROR(expr) \ - STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if 
(mediapipe::status_macro_internal::StatusAdaptorForMacros \ - status_macro_internal_adaptor = {(expr), __FILE__, __LINE__}) { \ - } else /* NOLINT */ \ +#define MP_RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), MEDIAPIPE_LOC}) { \ + } else /* NOLINT */ \ return status_macro_internal_adaptor.Consume() // Executes an expression `rexpr` that returns a `absl::StatusOr`. On @@ -156,14 +156,14 @@ return mediapipe::StatusBuilder( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__)) + MEDIAPIPE_LOC)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ mediapipe::StatusBuilder _( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__); \ + MEDIAPIPE_LOC); \ (void)_; /* error_expression is allowed to not use this variable */ \ return (error_expression)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ @@ -201,18 +201,17 @@ namespace status_macro_internal { // that declares a variable. class StatusAdaptorForMacros { public: - StatusAdaptorForMacros(const absl::Status& status, const char* file, int line) - : builder_(status, file, line) {} + StatusAdaptorForMacros(const absl::Status& status, source_location location) + : builder_(status, location) {} - StatusAdaptorForMacros(absl::Status&& status, const char* file, int line) - : builder_(std::move(status), file, line) {} + StatusAdaptorForMacros(absl::Status&& status, source_location location) + : builder_(std::move(status), location) {} - StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(const StatusBuilder& builder, + source_location /*location*/) : builder_(builder) {} - StatusAdaptorForMacros(StatusBuilder&& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(StatusBuilder&& builder, source_location /*location*/) : builder_(std::move(builder)) {} StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index c3241d911..989ee18f0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -17,7 +17,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) @@ -26,8 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - visibility = ["//visibility:public"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -39,13 +38,12 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( name = "classification_proto", srcs = ["classification.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -64,46 +62,39 @@ mediapipe_register_type( mediapipe_proto_library( name = "image_format_proto", srcs = 
["image_format.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "matrix_data_proto", srcs = ["matrix_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "location_data_proto", srcs = ["location_data.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "affine_transform_data_proto", srcs = ["affine_transform_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_file_properties_proto", srcs = ["image_file_properties.proto"], - visibility = ["//visibility:public"], ) cc_library( name = "deleting_file", srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", ], @@ -113,10 +104,9 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -129,13 +119,10 @@ cc_library( name = "affine_transform", srcs = ["affine_transform.cc"], hdrs = ["affine_transform.h"], - visibility = [ - "//visibility:public", - ], deps = [ + ":affine_transform_data_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:type_map", - "//mediapipe/framework/formats:affine_transform_data_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", @@ -154,9 +141,8 @@ cc_library( name = "image_frame", srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -179,10 +165,9 @@ cc_library( name = "image_frame_opencv", srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], - visibility = ["//visibility:public"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -206,11 +191,10 @@ cc_library( name = "location", srcs = ["location.cc"], hdrs = ["location.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -238,9 +222,9 @@ cc_library( name = "location_opencv", srcs = ["location_opencv.cc"], hdrs = ["location_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":location", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", ], alwayslink = 1, @@ -251,6 +235,7 @@ cc_test( srcs = ["location_opencv_test.cc"], deps = [ ":location_opencv", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", 
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", ], @@ -259,16 +244,14 @@ cc_test( cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) cc_library( name = "yuv_image", hdrs = ["yuv_image.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:integral_types", "@libyuv", @@ -280,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -292,7 +275,6 @@ cc_test( mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -310,9 +292,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = [ - "//visibility:public", - ], ) mediapipe_register_type( @@ -344,10 +323,9 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -374,10 +352,9 @@ cc_library( name = "image_multi_pool", srcs = ["image_multi_pool.cc"], hdrs = ["image_multi_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -411,10 +388,9 @@ cc_library( hdrs = [ "image_opencv.h", ], - visibility = ["//visibility:public"], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", @@ -425,7 +401,6 @@ cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], hdrs = ["image_frame_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "@com_google_absl//absl/memory", @@ -453,7 +428,13 @@ cc_library( "tensor.cc", "tensor_ahwb.cc", ], - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "//mediapipe/framework/formats/tensor:internal.h", + ] + select({ + "//mediapipe:ios": ["tensor_mtl_buffer_view.h"], + "//conditions:default": [], + }), copts = select({ "//mediapipe:apple": [ "-x objective-c++", @@ -476,8 +457,8 @@ cc_library( "-landroid", ], }), - visibility = ["//visibility:public"], deps = [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", @@ -505,3 +486,12 @@ cc_test( "//mediapipe/gpu:disable_gpu": [], }), ) + +cc_library( + name = "frame_buffer", + hdrs = ["frame_buffer.h"], + deps = [ + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/log:check", + ], +) diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 328001e85..9bcb7bccd 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -16,7 
+16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -24,12 +24,10 @@ mediapipe_proto_library( name = "locus_proto", srcs = ["locus.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "rasterization_proto", srcs = ["rasterization.proto"], - visibility = ["//visibility:public"], ) diff --git a/mediapipe/framework/formats/frame_buffer.h b/mediapipe/framework/formats/frame_buffer.h new file mode 100644 index 000000000..ccc699724 --- /dev/null +++ b/mediapipe/framework/formats/frame_buffer.h @@ -0,0 +1,161 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ + +#include + +#include "absl/log/check.h" +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe { + +// A `FrameBuffer` provides a view into the provided backing buffer (e.g. camera +// frame or still image) with buffer format information. FrameBuffer doesn't +// take ownership of the provided backing buffer. The caller is responsible to +// manage the backing buffer lifecycle for the lifetime of the FrameBuffer. +// +// Examples: +// +// // Create an metadata instance with no backing buffer. +// FrameBuffer buffer{/*planes=*/{}, dimension, kRGBA}; +// +// // Create an RGBA instance with backing buffer on single plane. +// FrameBuffer::Plane plane{rgba_buffer, /*stride=*/{dimension.width * 4, 4}}; +// FrameBuffer buffer{{plane}, dimension, kRGBA, kTopLeft)}; +// +// // Create an YUV instance with planar backing buffer. +// FrameBuffer::Plane y_plane{y_buffer, /*stride=*/{dimension.width , 1}}; +// FrameBuffer::Plane uv_plane{u_buffer, /*stride=*/{dimension.width, 2}}; +// FrameBuffer buffer{{y_plane, uv_plane}, dimension, kNV21}; +class FrameBuffer { + public: + // Colorspace formats. + enum class Format { + kRGBA, + kRGB, + kNV12, + kNV21, + kYV12, + kYV21, + kGRAY, + kUNKNOWN + }; + + // Stride information. + struct Stride { + // The row stride in bytes. This is the distance between the start pixels of + // two consecutive rows in the image. + int row_stride_bytes; + // This is the distance between two consecutive pixel values in a row of + // pixels in bytes. It may be larger than the size of a single pixel to + // account for interleaved image data or padded formats. 
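    // For example (illustrative values): a tightly packed RGBA plane that is
    // 640 pixels wide typically has pixel_stride_bytes == 4 and
    // row_stride_bytes == 640 * 4 == 2560, while padded formats can have a
    // larger row stride.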
+ int pixel_stride_bytes; + + bool operator==(const Stride& other) const { + return row_stride_bytes == other.row_stride_bytes && + pixel_stride_bytes == other.pixel_stride_bytes; + } + + bool operator!=(const Stride& other) const { return !operator==(other); } + }; + + // Plane encapsulates buffer and stride information. + struct Plane { + Plane(uint8* buffer, Stride stride) : buffer_(buffer), stride_(stride) {} + const uint8* buffer() const { return buffer_; } + uint8* mutable_buffer() { return buffer_; } + Stride stride() const { return stride_; } + + private: + uint8* buffer_; + Stride stride_; + }; + + // Dimension information for the whole frame or a cropped portion of it. + struct Dimension { + // The width dimension in pixel unit. + int width; + // The height dimension in pixel unit. + int height; + + bool operator==(const Dimension& other) const { + return width == other.width && height == other.height; + } + + bool operator!=(const Dimension& other) const { + return width != other.width || height != other.height; + } + + bool operator>=(const Dimension& other) const { + return width >= other.width && height >= other.height; + } + + bool operator<=(const Dimension& other) const { + return width <= other.width && height <= other.height; + } + + // Swaps width and height. + void Swap() { + using std::swap; + swap(width, height); + } + + // Returns area represented by width * height. + int Size() const { return width * height; } + }; + + // Builds a FrameBuffer object from a row-major backing buffer. + // + // The FrameBuffer does not take ownership of the backing buffer. The caller + // is responsible for maintaining the backing buffer lifecycle for the + // lifetime of FrameBuffer. + FrameBuffer(const std::vector& planes, Dimension dimension, + Format format) + : planes_(planes), dimension_(dimension), format_(format) {} + + // Returns number of planes. + int plane_count() const { return planes_.size(); } + + // Returns plane indexed by the input `index`. + const Plane& plane(int index) const { + CHECK_GE(index, 0); + CHECK_LT(static_cast(index), planes_.size()); + return planes_[index]; + } + + // Returns mutable plane indexed by the input `index`. + Plane mutable_plane(int index) { + CHECK_GE(index, 0); + CHECK_LT(static_cast(index), planes_.size()); + return planes_[index]; + } + + // Returns FrameBuffer dimension. + Dimension dimension() const { return dimension_; } + + // Returns FrameBuffer format. + Format format() const { return format_; } + + private: + std::vector planes_; + Dimension dimension_; + Format format_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 28e0bfc6a..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,22 +16,20 @@ # Description: # Working with dense optical flow in mediapipe. 
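For reference, a minimal construction sketch for the FrameBuffer class introduced above (a sketch only: the width, height, and buffer contents are assumed placeholder values, and FrameBuffer does not take ownership of the backing memory):

#include <vector>
#include "mediapipe/framework/formats/frame_buffer.h"

void BuildRgbaFrameBufferSketch() {
  constexpr int width = 64;
  constexpr int height = 48;
  // Tightly packed RGBA backing buffer owned by the caller.
  std::vector<uint8> rgba(width * height * 4);
  mediapipe::FrameBuffer::Plane plane(
      rgba.data(),
      /*stride=*/{/*row_stride_bytes=*/width * 4, /*pixel_stride_bytes=*/4});
  mediapipe::FrameBuffer frame({plane}, /*dimension=*/{width, height},
                               mediapipe::FrameBuffer::Format::kRGBA);
  // frame.plane_count() == 1 and frame.dimension().Size() == width * height.
}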
-licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") -package(default_visibility = ["//visibility:private"]) +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "optical_flow_field_data_cc_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], deps = [":optical_flow_field_data_proto"], ) @@ -39,15 +37,12 @@ cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], - visibility = [ - "//visibility:public", - ], deps = [ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/formats/object_detection/BUILD b/mediapipe/framework/formats/object_detection/BUILD index 39940acdc..35292e1cc 100644 --- a/mediapipe/framework/formats/object_detection/BUILD +++ b/mediapipe/framework/formats/object_detection/BUILD @@ -19,17 +19,15 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "anchor_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "anchor_cc_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], deps = [":anchor_proto"], ) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index ef0cddea4..1dbd8f8ac 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -25,8 +25,11 @@ #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_METAL_ENABLED +#import #include #include + +#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h" #else #include #endif // MEDIAPIPE_METAL_ENABLED @@ -61,6 +64,12 @@ int BhwcDepthFromShape(const Tensor::Shape& shape) { // 3) pad/"unpad" the bitmap after transfer CPU <-> GPU #if MEDIAPIPE_METAL_ENABLED +// No ODR violation here because this file compiled just once per project. +struct MtlResources { + id command_buffer = nil; + id device = nil; + id metal_buffer = nil; +}; namespace { // MTLBuffer can use existing properly aligned and allocated CPU memory. size_t AlignToPageSize(size_t size) { @@ -83,52 +92,56 @@ void DeallocateVirtualMemory(void* pointer, size_t size) { } } // namespace -Tensor::MtlBufferView Tensor::GetMtlBufferReadView( - id command_buffer) const { - LOG_IF(FATAL, valid_ == kValidNone) +void MtlBufferView::AllocateMtlBuffer(const Tensor& tensor, + id device) { + tensor.mtl_resources_->device = device; + if (!tensor.cpu_buffer_) { + // It also means that the metal buffer is not allocated yet. 
+ tensor.cpu_buffer_ = AllocateVirtualMemory(tensor.bytes()); + } + if (!tensor.mtl_resources_->metal_buffer) { + tensor.mtl_resources_->metal_buffer = [tensor.mtl_resources_->device + newBufferWithBytesNoCopy:tensor.cpu_buffer_ + length:AlignToPageSize(tensor.bytes()) + options:MTLResourceStorageModeShared | + MTLResourceCPUCacheModeDefaultCache + deallocator:^(void* pointer, NSUInteger length) { + DeallocateVirtualMemory(pointer, length); + }]; + } +} + +MtlBufferView MtlBufferView::GetReadView(const Tensor& tensor, + id command_buffer) { + LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidMetalBuffer))) + LOG_IF(FATAL, + !(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer))) << "Tensor conversion between different GPU resources is not supported " "yet."; - auto lock(absl::make_unique(&view_mutex_)); - valid_ |= kValidMetalBuffer; - AllocateMtlBuffer([command_buffer device]); - return {metal_buffer_, std::move(lock)}; + auto lock(absl::make_unique(&tensor.view_mutex_)); + tensor.valid_ |= Tensor::kValidMetalBuffer; + AllocateMtlBuffer(tensor, [command_buffer device]); + return {tensor.mtl_resources_->metal_buffer, std::move(lock)}; } -Tensor::MtlBufferView Tensor::GetMtlBufferWriteView( - id command_buffer) const { +MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor, + id command_buffer) { // Don't overwrite command buffer at which the metal buffer has been written // so we can wait until completed. - command_buffer_ = command_buffer; - return GetMtlBufferWriteView([command_buffer device]); + tensor.mtl_resources_->command_buffer = command_buffer; + return GetWriteView(tensor, [command_buffer device]); } -Tensor::MtlBufferView Tensor::GetMtlBufferWriteView( - id device) const { - auto lock(absl::make_unique(&view_mutex_)); - valid_ = kValidMetalBuffer; - AllocateMtlBuffer(device); - return {metal_buffer_, std::move(lock)}; -} - -void Tensor::AllocateMtlBuffer(id device) const { - device_ = device; - if (!cpu_buffer_) { - // It also means that the metal buffer is not allocated yet. - cpu_buffer_ = AllocateVirtualMemory(bytes()); - } - if (!metal_buffer_) { - metal_buffer_ = - [device_ newBufferWithBytesNoCopy:cpu_buffer_ - length:AlignToPageSize(bytes()) - options:MTLResourceStorageModeShared | - MTLResourceCPUCacheModeDefaultCache - deallocator:^(void* pointer, NSUInteger length) { - DeallocateVirtualMemory(pointer, length); - }]; - } +MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor, + id device) { + auto lock(absl::make_unique(&tensor.view_mutex_)); + tensor.valid_ = Tensor::kValidMetalBuffer; + AllocateMtlBuffer(tensor, device); + return {tensor.mtl_resources_->metal_buffer, std::move(lock)}; } +#else +struct MtlResources {}; #endif // MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 @@ -246,10 +259,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, return Tensor::OpenGlTexture2dView::Layout::kAligned; } } - // The best performance of a compute shader can be achived with textures' + // The best performance of a compute shader can be achieved with textures' // width multiple of 256. Making minimum fixed width of 256 waste memory for // small tensors. The optimal balance memory-vs-performance is power of 2. - // The texture width and height are choosen to be closer to square. + // The texture width and height are chosen to be closer to square. 
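  // Illustrative example (numbers assumed): for num_pixels = 1000,
  // sqrt(1000) ~= 31.6 and log2(31.6) ~= 4.98, so w = 1 << 4 = 16 and
  // h = (1000 + 16 - 1) / 16 = 63, giving a 16x63 layout rather than a long
  // 1xN strip.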
float power = std::log2(std::sqrt(static_cast(num_pixels))); w = 1 << static_cast(power); int h = (num_pixels + w - 1) / w; @@ -326,7 +339,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { - // If the call succeds then AHWB -> SSBO are synchronized so any usage of + // If the call succeeds then AHWB -> SSBO are synchronized so any usage of // the SSBO is correct after this call. if (!InsertAhwbToSsboFence()) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); @@ -348,8 +361,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { }; } -Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { +Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView( + uint64_t source_location_hash) const { auto lock(absl::make_unique(&view_mutex_)); + TrackAhwbUsage(source_location_hash); AllocateOpenGlBuffer(); valid_ = kValidOpenGlBuffer; return {opengl_buffer_, std::move(lock), nullptr}; @@ -361,7 +376,7 @@ void Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - if (!AllocateAhwbMapToSsbo()) { + if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); } glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); @@ -377,6 +392,9 @@ Tensor& Tensor::operator=(Tensor&& src) { return *this; } +Tensor::Tensor(Tensor&& src) { Move(&src); } +Tensor::~Tensor() { Invalidate(); } + void Tensor::Move(Tensor* src) { valid_ = src->valid_; src->valid_ = kValidNone; @@ -385,13 +403,8 @@ void Tensor::Move(Tensor* src) { src->element_type_ = ElementType::kNone; // Mark as invalidated. cpu_buffer_ = src->cpu_buffer_; src->cpu_buffer_ = nullptr; -#if MEDIAPIPE_METAL_ENABLED - device_ = src->device_; - command_buffer_ = src->command_buffer_; - metal_buffer_ = src->metal_buffer_; - src->metal_buffer_ = nil; -#endif // MEDIAPIPE_METAL_ENABLED - + ahwb_tracking_key_ = src->ahwb_tracking_key_; + mtl_resources_ = std::move(src->mtl_resources_); MoveAhwbStuff(src); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 @@ -410,12 +423,15 @@ void Tensor::Move(Tensor* src) { } Tensor::Tensor(ElementType element_type, const Shape& shape) - : element_type_(element_type), shape_(shape) {} + : element_type_(element_type), + shape_(shape), + mtl_resources_(std::make_unique()) {} Tensor::Tensor(ElementType element_type, const Shape& shape, const QuantizationParameters& quantization_parameters) : element_type_(element_type), shape_(shape), - quantization_parameters_(quantization_parameters) {} + quantization_parameters_(quantization_parameters), + mtl_resources_(std::make_unique()) {} #if MEDIAPIPE_METAL_ENABLED void Tensor::Invalidate() { @@ -427,11 +443,16 @@ void Tensor::Invalidate() { absl::MutexLock lock(&view_mutex_); // If memory is allocated and not owned by the metal buffer. // TODO: Re-design cpu buffer memory management. - if (cpu_buffer_ && !metal_buffer_) { + if (cpu_buffer_ && !mtl_resources_->metal_buffer) { DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); } - metal_buffer_ = nil; cpu_buffer_ = nullptr; + // This becomes NULL if the tensor is moved. 
+ if (mtl_resources_) { + mtl_resources_->metal_buffer = nil; + mtl_resources_->command_buffer = nil; + mtl_resources_->device = nil; + } #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 // Don't need to wait for the resource to be deleted bacause if will be // released on last reference deletion inside the OpenGL driver. @@ -525,10 +546,11 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { // GPU-to-CPU synchronization and read-back. #if MEDIAPIPE_METAL_ENABLED if (valid_ & kValidMetalBuffer) { - LOG_IF(FATAL, !command_buffer_) << "Metal -> CPU synchronization " - "requires MTLCommandBuffer to be set."; - if (command_buffer_) { - [command_buffer_ waitUntilCompleted]; + LOG_IF(FATAL, !mtl_resources_->command_buffer) + << "Metal -> CPU synchronization " + "requires MTLCommandBuffer to be set."; + if (mtl_resources_->command_buffer) { + [mtl_resources_->command_buffer waitUntilCompleted]; } } #endif // MEDIAPIPE_METAL_ENABLED @@ -547,7 +569,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { }); } else #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - + { // Transfer data from texture if not transferred from SSBO/MTLBuffer // yet. if (valid_ & kValidOpenGlTexture2d) { @@ -578,14 +600,17 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { } }); } + } #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 valid_ |= kValidCpu; } return {cpu_buffer_, std::move(lock)}; } -Tensor::CpuWriteView Tensor::GetCpuWriteView() const { +Tensor::CpuWriteView Tensor::GetCpuWriteView( + uint64_t source_location_hash) const { auto lock = absl::make_unique(&view_mutex_); + TrackAhwbUsage(source_location_hash); AllocateCpuBuffer(); valid_ = kValidCpu; #ifdef MEDIAPIPE_TENSOR_USE_AHWB @@ -605,7 +630,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const { void Tensor::AllocateCpuBuffer() const { if (!cpu_buffer_) { #ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (AllocateAHardwareBuffer()) return; + if (use_ahwb_ && AllocateAHardwareBuffer()) return; #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_METAL_ENABLED cpu_buffer_ = AllocateVirtualMemory(bytes()); @@ -615,24 +640,4 @@ void Tensor::AllocateCpuBuffer() const { } } -void Tensor::SetPreferredStorageType(StorageType type) { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (__builtin_available(android 26, *)) { - use_ahwb_ = type == StorageType::kAhwb; - VLOG(4) << "Tensor: use of AHardwareBuffer is " - << (use_ahwb_ ? "allowed" : "not allowed"); - } -#else - VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed"; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - -Tensor::StorageType Tensor::GetPreferredStorageType() { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - return use_ahwb_ ? 
StorageType::kAhwb : StorageType::kDefault; -#else - return StorageType::kDefault; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - } // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index ff9da3ec6..1d670d805 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,13 +24,11 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor/internal.h" #include "mediapipe/framework/port.h" -#if MEDIAPIPE_METAL_ENABLED -#import -#endif // MEDIAPIPE_METAL_ENABLED #ifndef MEDIAPIPE_NO_JNI #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #define MEDIAPIPE_TENSOR_USE_AHWB 1 @@ -39,18 +37,32 @@ #endif // MEDIAPIPE_NO_JNI #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include #include - -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_context.h" #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 -namespace mediapipe { +#if defined __has_builtin +#if __has_builtin(__builtin_LINE) +#define builtin_LINE __builtin_LINE +#endif +#if __has_builtin(__builtin_FILE) +#define builtin_FILE __builtin_FILE +#endif +#endif +#ifndef builtin_LINE +#define builtin_LINE() 0 +#endif +#ifndef builtin_FILE +#define builtin_FILE() "" +#endif + +namespace mediapipe { // Tensor is a container of multi-dimensional data that supports sharing the // content across different backends and APIs, currently: CPU / Metal / OpenGL. // Texture2DView is limited to 4 dimensions. @@ -66,7 +78,7 @@ namespace mediapipe { // GLuint buffer = view.buffer(); // Then the buffer can be bound to the GPU command buffer. // ...binding the buffer to the command buffer... -// ...commiting command buffer and releasing the view... +// ...committing command buffer and releasing the view... // // The following request for the CPU view will be blocked until the GPU view is // released and the GPU task is finished. @@ -75,6 +87,7 @@ namespace mediapipe { // float* pointer = view.buffer(); // ...reading the cpu memory... +struct MtlResources; class Tensor { class View { public: @@ -128,9 +141,9 @@ class Tensor { Tensor(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete; // Move-only. - Tensor(Tensor&& src) { Move(&src); } + Tensor(Tensor&& src); Tensor& operator=(Tensor&&); - ~Tensor() { Invalidate(); } + ~Tensor(); template class CpuView : public View { @@ -162,34 +175,9 @@ class Tensor { using CpuReadView = CpuView; CpuReadView GetCpuReadView() const; using CpuWriteView = CpuView; - CpuWriteView GetCpuWriteView() const; - -#if MEDIAPIPE_METAL_ENABLED - // TODO: id vs. MtlBufferView. - class MtlBufferView : public View { - public: - id buffer() const { return buffer_; } - MtlBufferView(MtlBufferView&& src) - : View(std::move(src)), buffer_(src.buffer_) { - src.buffer_ = nil; - } - - protected: - friend class Tensor; - MtlBufferView(id buffer, std::unique_ptr&& lock) - : View(std::move(lock)), buffer_(buffer) {} - id buffer_; - }; - // The command buffer status is checked for completeness if GPU-to-CPU - // synchronization is required. - // TODO: Design const and non-const view acquiring. 
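  // Design note (a reading of this change, not an authoritative statement of
  // intent): the Metal-specific view now lives in tensor_mtl_buffer_view.h and
  // the Metal handles live in the forward-declared MtlResources struct, so
  // tensor.h no longer has to import Metal headers. Because mtl_resources_ is
  // a std::unique_ptr to that incomplete type, the move constructor and the
  // destructor are only declared here and are defined out of line in
  // tensor.cc, where MtlResources is a complete type.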
- MtlBufferView GetMtlBufferReadView(id command_buffer) const; - MtlBufferView GetMtlBufferWriteView( - id command_buffer) const; - // Allocate new buffer. - // TODO: GPU-to-CPU design considerations. - MtlBufferView GetMtlBufferWriteView(id device) const; -#endif // MEDIAPIPE_METAL_ENABLED + CpuWriteView GetCpuWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #ifdef MEDIAPIPE_TENSOR_USE_AHWB using FinishingFunc = std::function; @@ -306,7 +294,9 @@ class Tensor { // A valid OpenGL context must be bound to the calling thread due to possible // GPU resource allocation. OpenGlBufferView GetOpenGlBufferReadView() const; - OpenGlBufferView GetOpenGlBufferWriteView() const; + OpenGlBufferView GetOpenGlBufferWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 const Shape& shape() const { return shape_; } @@ -350,15 +340,9 @@ class Tensor { bool ready_as_opengl_texture_2d() const { return valid_ & kValidOpenGlTexture2d; } - // Sets the type of underlying resource that is going to be allocated. - enum class StorageType { - kDefault, - kAhwb, - }; - static void SetPreferredStorageType(StorageType type); - static StorageType GetPreferredStorageType(); private: + friend class MtlBufferView; void Move(Tensor*); void Invalidate(); @@ -383,12 +367,9 @@ class Tensor { mutable void* cpu_buffer_ = nullptr; void AllocateCpuBuffer() const; -#if MEDIAPIPE_METAL_ENABLED - mutable id command_buffer_; - mutable id device_; - mutable id metal_buffer_; - void AllocateMtlBuffer(id device) const; -#endif // MEDIAPIPE_METAL_ENABLED + // Forward declaration of the MtlResources provides compile-time verification + // of ODR if this header includes any actual code that uses MtlResources. + mutable std::unique_ptr mtl_resources_; #ifdef MEDIAPIPE_TENSOR_USE_AHWB mutable AHardwareBuffer* ahwb_ = nullptr; @@ -409,9 +390,14 @@ class Tensor { mutable std::function release_callback_; bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; - // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; #endif // MEDIAPIPE_TENSOR_USE_AHWB + // Use Ahwb for other views: OpenGL / CPU buffer. + mutable bool use_ahwb_ = false; + mutable uint64_t ahwb_tracking_key_ = 0; + // TODO: Tracks all unique tensors. Can grow to a large number. LRU + // (Least Recently Used) can be more predicted. + // The value contains the size alignment parameter. + static inline absl::flat_hash_map ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; @@ -419,6 +405,9 @@ class Tensor { void ReleaseAhwbStuff(); void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; + void MoveCpuOrSsboToAhwb() const; + // Set current tracking key, set "use ahwb" if the key is already marked. + void TrackAhwbUsage(uint64_t key) const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor/BUILD b/mediapipe/framework/formats/tensor/BUILD new file mode 100644 index 000000000..3895fc82e --- /dev/null +++ b/mediapipe/framework/formats/tensor/BUILD @@ -0,0 +1,24 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:private"], + features = ["-layering_check"], +) + +licenses(["notice"]) + +exports_files([ + "internal.h", +]) diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor/internal.h similarity index 92% rename from mediapipe/framework/formats/tensor_internal.h rename to mediapipe/framework/formats/tensor/internal.h index 1231a991c..c223c5b1d 100644 --- a/mediapipe/framework/formats/tensor_internal.h +++ b/mediapipe/framework/formats/tensor/internal.h @@ -18,8 +18,6 @@ #include #include -#include "mediapipe/framework/tool/type_util.h" - namespace mediapipe { // Generates unique view id at compile-time using FILE and LINE. @@ -41,10 +39,12 @@ namespace tensor_internal { // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function constexpr uint64_t kFnvPrime = 0x00000100000001B3; constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; -constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { - return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); +constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) { + return (value2 ^ value1) * kFnvPrime; +} +constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { + return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0])); } - template struct TypeList { static constexpr std::size_t size{sizeof...(Ts)}; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index b11f6b55b..525f05f31 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -4,12 +4,13 @@ #include "mediapipe/framework/formats/tensor.h" #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include + #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gl_base.h" -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB namespace mediapipe { @@ -211,13 +212,15 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(!(valid_ & kValidOpenGlTexture2d)) << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " "supported."; - CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) - << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on targe system."; + bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; valid_ |= kValidAHardwareBuffer; - if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + if (transfer) { + MoveCpuOrSsboToAhwb(); + } else { + if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + } return {ahwb_, ssbo_written_, &fence_fd_, // The FD is created for SSBO -> AHWB synchronization. 
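A minimal sketch of the conversion path this hunk enables on builds where MEDIAPIPE_TENSOR_USE_AHWB is defined (placeholder shape and a no-op finish callback; it mirrors the TestReplacingCpuByAhwb test added further below): a tensor that already holds CPU-backed data can now hand out an AHardwareBuffer view, with MoveCpuOrSsboToAhwb() migrating the existing contents into AHWB-backed memory instead of failing the old interoperability CHECK.

#include "mediapipe/framework/formats/tensor.h"

void CpuThenAhwbSketch() {
  mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kFloat32,
                           mediapipe::Tensor::Shape{4});
  {
    // The tensor is CPU-backed after this write view is taken.
    auto view = tensor.GetCpuWriteView();
    view.buffer<float>()[0] = 1.0f;
  }
  {
    // Requesting an AHWB read view moves the CPU contents into the
    // AHardwareBuffer; the callback stands in for real synchronization.
    auto view = tensor.GetAHardwareBufferReadView();
    view.SetReadingFinishedFunc([](bool) { return true; });
  }
}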
@@ -262,7 +265,15 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { - if (!use_ahwb_) return false; + // Mark current tracking key as Ahwb-use. + if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_); + it != ahwb_usage_track_.end()) { + size_alignment = it->second; + } else if (ahwb_tracking_key_ != 0) { + ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment}); + } + use_ahwb_ = true; + if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; @@ -302,6 +313,43 @@ bool Tensor::AllocateAhwbMapToSsbo() const { return false; } +// Moves Cpu/Ssbo resource under the Ahwb backed memory. +void Tensor::MoveCpuOrSsboToAhwb() const { + void* dest = nullptr; + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_lock( + ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + } + if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + valid_ &= ~kValidCpu; + } else if (valid_ & kValidOpenGlBuffer) { + gl_context_->Run([this, dest]() { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); + const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), + GL_MAP_READ_BIT); + std::memcpy(dest, src, bytes()); + glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + glDeleteBuffers(1, &opengl_buffer_); + }); + opengl_buffer_ = GL_INVALID_INDEX; + gl_context_ = nullptr; + // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top + // of the Ahwb at the next request to the OpenGlBufferView. + valid_ &= ~kValidOpenGlBuffer; + } else { + LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; + } + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_unlock(ahwb_, nullptr); + CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + } +} + // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before // the GPU task that is going to read from the SSBO. When the writing into AHWB // is finished then the GPU reads from the SSBO. @@ -408,6 +456,17 @@ void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { + if (ahwb_tracking_key_ == 0) { + ahwb_tracking_key_ = source_location_hash; + for (int dim : shape_.dims) { + ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); + } + } + // Keep flag value if it was set previously. 
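  // Illustrative walk-through (call sites assumed for illustration): the key
  // combines the caller's file/line hash (the defaulted source_location_hash
  // argument of GetCpuWriteView() / GetOpenGlBufferWriteView()) with the
  // tensor's dimensions. Once AllocateAHardwareBuffer() records that key in
  // ahwb_usage_track_, a later tensor created from the same call site with
  // the same shape matches the key here and switches to AHWB-backed storage
  // up front.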
+ use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_); +} + #else // MEDIAPIPE_TENSOR_USE_AHWB bool Tensor::AllocateAhwbMapToSsbo() const { return false; } @@ -416,6 +475,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {} void Tensor::ReleaseAhwbStuff() {} void* Tensor::MapAhwbToCpuRead() const { return nullptr; } void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t key) const {} #endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc new file mode 100644 index 000000000..b06bd3ef2 --- /dev/null +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -0,0 +1,210 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor/views/data_types.h" +#include "mediapipe/gpu/gpu_test_base.h" +#include "mediapipe/gpu/shader_util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "testing/base/public/gunit.h" + +// The test creates OpenGL ES buffer, fills the buffer with incrementing values +// 0.0, 0.1, 0.2 etc. with the compute shader on GPU. +// Then the test requests the CPU view and compares the values. +// Float32 and Float16 tests are there. + +namespace { + +using mediapipe::Float16; +using mediapipe::Tensor; + +MATCHER_P(NearWithPrecision, precision, "") { + return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; +} + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +// Utility function to fill the GPU buffer. +void FillGpuBuffer(GLuint name, std::size_t size, + const Tensor::ElementType fmt) { + std::string shader_source; + if (fmt == Tensor::ElementType::kFloat32) { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x * 2u; + output_data.elements[v] = float(v) / 10.0; + output_data.elements[v + 1u] = float(v + 1u) / 10.0; + })"; + } else { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x; + uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0, + (float(v) * 2.0 + 1.0) / 10.0)); + output_data.elements[v] = uintBitsToFloat(tmp); + })"; + } + GLuint shader; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER)); + const GLchar* sources[] = {shader_source.c_str()}; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader)); + GLint is_compiled = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS, + &is_compiled)); + if (is_compiled == GL_FALSE) { + GLint max_length = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH, + &max_length)); + std::vector error_log(max_length); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderInfoLog, shader, max_length, + &max_length, error_log.data())); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); + FAIL() << error_log.data(); + return; + } + GLuint to_buffer_program; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program)); 
+ MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program)); + + MP_ASSERT_OK( + TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); +} + +class TensorAhwbGpuTest : public mediapipe::GpuTestBase { + public: +}; + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + ASSERT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + ASSERT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + ASSERT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + ASSERT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + ASSERT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + ASSERT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + // Precision is set to a reasonable value for Float16. + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(NearWithPrecision(0.001), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { + // Request the CPU view to get the memory to be allocated. + // Request Ahwb view then to transform the storage into Ahwb. 
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  {
+    auto ptr = tensor.GetCpuWriteView().buffer<float>();
+    ASSERT_NE(ptr, nullptr);
+    for (int i = 0; i < num_elements; i++) {
+      ptr[i] = static_cast<float>(i) / 10.0f;
+    }
+  }
+  {
+    auto view = tensor.GetAHardwareBufferReadView();
+    ASSERT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
+  }
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  ASSERT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
+  // Request the GPU view to get the ssbo allocated internally.
+  // Then request the Ahwb view to transform the storage into Ahwb.
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    ASSERT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  {
+    auto view = tensor.GetAHardwareBufferReadView();
+    ASSERT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
+  }
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  ASSERT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+}  // namespace
+
+#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
+        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc
index 7ab5a4925..69e49dd58 100644
--- a/mediapipe/framework/formats/tensor_ahwb_test.cc
+++ b/mediapipe/framework/formats/tensor_ahwb_test.cc
@@ -1,34 +1,28 @@
 #include "mediapipe/framework/formats/tensor.h"
-#include "mediapipe/gpu/gpu_test_base.h"
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-#if !MEDIAPIPE_DISABLE_GPU
-
 namespace mediapipe {
-class TensorAhwbTest : public mediapipe::GpuTestBase {
- public:
-};
-
-TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
+TEST(TensorAhwbTest, TestCpuThenAHWB) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
     auto ptr = tensor.GetCpuWriteView().buffer<float>();
     EXPECT_NE(ptr, nullptr);
   }
   {
-    auto ahwb = tensor.GetAHardwareBufferReadView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
   }
 }
 
-TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
+TEST(TensorAhwbTest, TestAHWBThenCpu) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
-    auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
   }
   {
     auto ptr = tensor.GetCpuReadView().buffer<float>();
@@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
   }
 }
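Aside: the tests above and below all follow the same caller-side handshake when an AHardwareBuffer view is mixed with CPU views. A minimal sketch of that pattern follows; it only uses calls that appear in these tests, and the shape, the no-op callbacks, and the function name are illustrative, not part of the patch.

#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include "mediapipe/framework/formats/tensor.h"

// Sketch only: produce into an AHWB-backed tensor, then read it back on CPU.
void WriteAhwbThenReadOnCpu() {
  mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kFloat32,
                           mediapipe::Tensor::Shape{16});
  {
    // Requesting the AHWB view first makes the tensor allocate AHWB storage.
    auto view = tensor.GetAHardwareBufferWriteView();
    // A real producer would write through view.handle() here; the tests pass
    // -1 (no fence) and a callback that simply reports completion.
    view.SetWritingFinishedFD(-1, [](bool) { return true; });
  }
  // The CPU read view synchronizes with the writer registered above.
  auto data = tensor.GetCpuReadView().buffer<float>();
  (void)data;  // ... use data ...
}
#endif  // MEDIAPIPE_TENSOR_USE_AHWB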
-TEST_F(TensorAhwbTest, TestCpuThenGl) {
-  RunInGlContext([] {
-    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
+TEST(TensorAhwbTest, TestAhwbAlignment) {
+  Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
+  {
+    auto view = tensor.GetAHardwareBufferWriteView(16);
+    ASSERT_NE(view.handle(), nullptr);
+    if (__builtin_available(android 26, *)) {
+      AHardwareBuffer_Desc desc;
+      AHardwareBuffer_describe(view.handle(), &desc);
+      // sizeof(float) * 5 = 20; the closest size aligned to 16 is 32.
+      EXPECT_EQ(desc.width, 32);
+    }
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
+  }
+}
+
+// Tensor::GetCpuWriteView relies on a source-location mechanism that records
+// the source file name and line of the call site. This helper exists so that
+// the calls below share the same source file name and line.
+auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
+
+// The test checks the tracking mechanism: when a tensor's Cpu view is
+// retrieved for the first time, the source location of that call is attached
+// to the tensor. If an Ahwb view is later requested from the same tensor, the
+// recorded Cpu view source location is marked as one that should use Ahwb
+// storage. When a Cpu view is then requested from that same source location
+// for a newly allocated tensor, Ahwb storage is allocated for the Cpu view up
+// front.
+TEST(TensorAhwbTest, TestTrackingAhwb) {
+  // Create the first tensor and request the Cpu and then the Ahwb view to mark
+  // the source location for Ahwb storage.
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
     {
-      auto ptr = tensor.GetCpuWriteView().buffer<float>();
-      EXPECT_NE(ptr, nullptr);
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
     }
     {
-      auto ssbo = tensor.GetOpenGlBufferReadView().name();
-      EXPECT_GT(ssbo, 0);
+      // Align the Ahwb size to a multiple of 16.
+      auto view = tensor.GetAHardwareBufferWriteView(16);
+      EXPECT_NE(view.handle(), nullptr);
+      view.SetReadingFinishedFunc([](bool) { return true; });
     }
-  });
+  }
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      // The second tensor uses the same Cpu view source location, so Ahwb
+      // storage is allocated internally.
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Check that the Ahwb size is aligned to a multiple of 16. The alignment
+      // was recorded by the earlier request of the Ahwb view.
+      auto view = tensor.GetAHardwareBufferReadView();
+      EXPECT_NE(view.handle(), nullptr);
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc;
+        AHardwareBuffer_describe(view.handle(), &desc);
+        // sizeof(float) * 9 = 36. The closest aligned size is 48.
+        EXPECT_EQ(desc.width, 48);
+      }
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
 }
 
 }  // namespace mediapipe
-
-#endif  // !MEDIAPIPE_DISABLE_GPU
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB
diff --git a/mediapipe/framework/formats/tensor_mtl_buffer_view.h b/mediapipe/framework/formats/tensor_mtl_buffer_view.h
new file mode 100644
index 000000000..a61659d3d
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_mtl_buffer_view.h
@@ -0,0 +1,61 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ + +#import + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" + +namespace mediapipe { +class MtlBufferView : public Tensor::View { + public: + // The command buffer status is checked for completeness if GPU-to-CPU + // synchronization is required. + static MtlBufferView GetReadView(const Tensor& tensor, + id command_buffer); + static MtlBufferView GetWriteView(const Tensor& tensor, + id command_buffer); + static MtlBufferView GetWriteView(const Tensor& tensor, id device); + + id buffer() const { return buffer_; } + MtlBufferView(MtlBufferView&& src) + : Tensor::View(std::move(src)), buffer_(src.buffer_) { + src.buffer_ = nil; + } + + protected: + friend class Tensor; + static void AllocateMtlBuffer(const Tensor& tensor, id device); + MtlBufferView(id buffer, std::unique_ptr&& lock) + : Tensor::View(std::move(lock)), buffer_(buffer) {} + id buffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_ diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index d1dffa414..a7bd9ef43 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -354,7 +354,9 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } *min_stream_timestamp = std::min(min_packet, min_bound); - if (*min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp >= Timestamp::OneOverPostStream()) { + // Either OneOverPostStream or Done indicates no more packets. + *min_stream_timestamp = Timestamp::Done(); last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream(); return NodeReadiness::kReadyForClose; } diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index fdc5fe077..718174c45 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream { friend class GraphProfiler; // Accesses OutputStreamShard for profiling. friend class GraphTracer; + // Accesses OutputStreamShard for profiling. + friend class PerfettoTraceScope; // Accesses OutputStreamShard for post processing. friend class OutputStreamManager; }; diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 87944d80f..1039dc1c6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-parse_headers"], ) @@ -28,7 +28,6 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "1", }, - visibility = ["//visibility:public"], ) #TODO : remove from OSS. 
@@ -37,13 +36,11 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "0", }, - visibility = ["//visibility:public"], ) cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:aligned_malloc_and_free", "@com_google_absl//absl/base:core_headers", @@ -57,7 +54,6 @@ cc_library( "advanced_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":advanced_proto_lite", ":core_proto", @@ -72,7 +68,6 @@ cc_library( "advanced_proto_lite_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", "//mediapipe/framework:port", @@ -83,7 +78,6 @@ cc_library( cc_library( name = "any_proto", hdrs = ["any_proto.h"], - visibility = ["//visibility:public"], deps = [ ":core_proto", ], @@ -94,7 +88,6 @@ cc_library( hdrs = [ "commandlineflags.h", ], - visibility = ["//visibility:public"], deps = [ "//third_party:glog", "@com_google_absl//absl/flags:flag", @@ -107,7 +100,6 @@ cc_library( "core_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_protobuf//:protobuf", @@ -117,7 +109,6 @@ cc_library( cc_library( name = "file_helpers", hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//mediapipe/framework/deps:file_helpers", @@ -128,7 +119,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "//mediapipe/framework/deps:image_resizer", @@ -140,14 +130,12 @@ cc_library( cc_library( name = "integral_types", hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( name = "benchmark", testonly = 1, hdrs = ["benchmark.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_benchmark//:benchmark", ], @@ -158,7 +146,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:re2", ], @@ -173,7 +160,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -190,7 +176,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -204,7 +189,6 @@ cc_library( hdrs = [ "logging.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//third_party:glog", @@ -217,7 +201,6 @@ cc_library( hdrs = [ "map_util.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:map_util", @@ -227,7 +210,6 @@ cc_library( cc_library( name = "numbers", hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:numbers"], ) @@ -238,13 +220,11 @@ config_setting( define_values = { "MEDIAPIPE_DISABLE_OPENCV": "1", }, - visibility = ["//visibility:public"], ) cc_library( name = "opencv_core", hdrs = ["opencv_core_inc.h"], - visibility = ["//visibility:public"], deps = [ "//third_party:opencv", ], @@ -253,7 +233,6 @@ cc_library( cc_library( name = "opencv_imgproc", hdrs = ["opencv_imgproc_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -263,7 +242,6 @@ cc_library( cc_library( name = "opencv_imgcodecs", hdrs = ["opencv_imgcodecs_inc.h"], - visibility = ["//visibility:public"], 
deps = [ ":opencv_core", "//third_party:opencv", @@ -273,7 +251,6 @@ cc_library( cc_library( name = "opencv_highgui", hdrs = ["opencv_highgui_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -283,7 +260,6 @@ cc_library( cc_library( name = "opencv_video", hdrs = ["opencv_video_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -294,7 +270,6 @@ cc_library( cc_library( name = "opencv_features2d", hdrs = ["opencv_features2d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -304,20 +279,28 @@ cc_library( cc_library( name = "opencv_calib3d", hdrs = ["opencv_calib3d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", ], ) +cc_library( + name = "opencv_videoio", + hdrs = ["opencv_videoio_inc.h"], + deps = [ + ":opencv_core", + "//mediapipe/framework:port", + "//third_party:opencv", + ], +) + cc_library( name = "parse_text_proto", hdrs = [ "parse_text_proto.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", ":logging", @@ -328,14 +311,12 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:point"], ) cc_library( name = "port", hdrs = ["port.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/base:core_headers", @@ -345,14 +326,12 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:rectangle"], ) cc_library( name = "ret_check", hdrs = ["ret_check.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:ret_check", @@ -362,7 +341,6 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:singleton"], ) @@ -371,7 +349,6 @@ cc_library( hdrs = [ "source_location.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:source_location", @@ -386,7 +363,6 @@ cc_library( "status_builder.h", "status_macros.h", ], - visibility = ["//visibility:public"], deps = [ ":source_location", "//mediapipe/framework:port", @@ -401,7 +377,6 @@ cc_library( hdrs = [ "statusor.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/status:statusor", @@ -412,7 +387,6 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], - visibility = ["//visibility:private"], deps = [ ":status", "@com_google_googletest//:gtest", @@ -422,7 +396,6 @@ cc_library( cc_library( name = "threadpool", hdrs = ["threadpool.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], @@ -449,7 +422,6 @@ alias( cc_library( name = "topologicalsorter", hdrs = ["topologicalsorter.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:topologicalsorter", @@ -459,6 +431,5 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:vector"], ) diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 
80e9bfc4d..94a4a5646 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -214,10 +214,10 @@ def mediapipe_ts_library( """Generate ts_project for MediaPipe open source version. Args: - name: the name of the cc_proto_library. - srcs: the .proto files of the cc_proto_library for Bazel use. + name: the name of the mediapipe_ts_library. + srcs: the .ts files of the mediapipe_ts_library for Bazel use. visibility: visibility of this target. - deps: a list of dependency labels for Bazel use; must be cc_proto_library. + deps: a list of dependency labels for Bazel use. testonly: test only or not. allow_unoptimized_namespaces: ignored, used only internally """ @@ -228,6 +228,8 @@ def mediapipe_ts_library( srcs = srcs, visibility = visibility, deps = deps + [ + "@npm//@types/jasmine", + "@npm//@types/node", "@npm//@types/offscreencanvas", "@npm//@types/google-protobuf", ], @@ -235,3 +237,36 @@ def mediapipe_ts_library( declaration = True, tsconfig = "//:tsconfig.json", )) + +def mediapipe_ts_declaration( + name, + srcs, + visibility = None, + deps = []): + """Generate ts_declaration for MediaPipe open source version. + + Args: + name: the name of the mediapipe_ts_declaration. + srcs: the .d.ts files of the mediapipe_ts_declaration for Bazel use. + visibility: visibility of this target. + deps: a list of dependency labels for Bazel use + """ + + # Bazel does not create JS files for .d.ts files, which leads to import + # failures in our open source build. We simply re-name the .d.ts files + # to .ts to work around this problem. + for src in srcs: + native.genrule( + name = replace_suffix(src, ".d.ts", "_d_ts"), + srcs = [src], + outs = [replace_suffix(src, ".d.ts", ".ts")], + visibility = visibility, + cmd = "cp -n $< $@;", + ) + + mediapipe_ts_library( + name = name, + srcs = [replace_suffix(src, ".d.ts", "_d_ts") for src in srcs], + visibility = visibility, + deps = deps, + ) diff --git a/mediapipe/framework/port/opencv_videoio_inc.h b/mediapipe/framework/port/opencv_videoio_inc.h new file mode 100644 index 000000000..63029b69f --- /dev/null +++ b/mediapipe/framework/port/opencv_videoio_inc.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ +#define MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ + +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "third_party/OpenCV/videoio.hpp" + +#endif // MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ diff --git a/mediapipe/framework/port/proto_ns.h b/mediapipe/framework/port/proto_ns.h index 83aecdf49..53b854ff7 100644 --- a/mediapipe/framework/port/proto_ns.h +++ b/mediapipe/framework/port/proto_ns.h @@ -17,8 +17,9 @@ #include -// Temporary forward declarations for proto2 support on portable targets. -// Use proto_ns inside namespace mediapipe instead of proto2 namespace. +// Temporary forward declarations for google::protobuf support on portable +// targets. 
Use proto_ns inside namespace mediapipe instead of google::protobuf +// namespace. #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/repeated_field.h" diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 237aa825f..6184ed45b 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -127,6 +127,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/tool:name_util", + ":web_performance_profiling", ] + select({ "//conditions:default": [], }) + select({ @@ -140,7 +141,7 @@ cc_library( name = "circular_buffer", hdrs = ["circular_buffer.h"], visibility = [ - "//visibility:public", + "//mediapipe:__subpackages__", ], deps = [ "//mediapipe/framework/port:integral_types", @@ -151,7 +152,6 @@ cc_test( name = "circular_buffer_test", size = "small", srcs = ["circular_buffer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":circular_buffer", "//mediapipe/framework/port:gtest_main", @@ -164,7 +164,7 @@ cc_library( name = "trace_buffer", srcs = ["trace_buffer.h"], hdrs = ["trace_buffer.h"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework/profiler:__subpackages__"], deps = [ ":circular_buffer", "//mediapipe/framework:calculator_profile_cc_proto", @@ -276,6 +276,25 @@ cc_test( ], ) +config_setting( + name = "mediapipe_web_profiling_enabled", + values = { + "define": "MEDIAPIPE_WEB_PROFILING=1", + }, + visibility = ["//visibility:private"], +) + +cc_library( + name = "web_performance_profiling", + hdrs = ["web_performance_profiling.h"], + defines = select({ + ":mediapipe_web_profiling_enabled": ["MEDIAPIPE_WEB_PROFILING_ENABLED"], + "//conditions:default": [], + }), + visibility = ["//mediapipe:__subpackages__"], + deps = ["@com_google_absl//absl/strings"], +) + cc_library( name = "profiler_resource_util", srcs = ["profiler_resource_util_common.cc"] + select({ @@ -285,6 +304,7 @@ cc_library( "//mediapipe:ios": ["profiler_resource_util_ios.cc"], }), hdrs = ["profiler_resource_util.h"], + # We use Objective-C++ on iOS. 
copts = select({ "//conditions:default": [], @@ -292,9 +312,7 @@ cc_library( "-ObjC++", ], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/flags:flag", "//mediapipe/framework/port:logging", @@ -334,6 +352,10 @@ cc_library( "graph_profiler_stub.h", ], visibility = ["//mediapipe/framework:__pkg__"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + ], ) cc_test( diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index f14acfc78..6aead5250 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -194,6 +194,7 @@ void GraphProfiler::Initialize( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique(this); + graph_id_ = ++next_instance_id_; is_initialized_ = true; } diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 29969af2e..6358cb057 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -232,6 +232,14 @@ class GraphProfiler : public std::enable_shared_from_this { const ProfilerConfig& profiler_config() { return profiler_config_; } + // Helper method to expose the config to other profilers. + const ValidatedGraphConfig* GetValidatedGraphConfig() { + return validated_graph_; + } + + // Gets a numerical identifier for this GraphProfiler object. + uint64_t GetGraphId() { return graph_id_; } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a @@ -352,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this { class GraphProfileBuilder; std::unique_ptr profile_builder_; + // The globally incrementing identifier for all graphs in a process. + static inline std::atomic_int next_instance_id_ = 0; + + // A unique identifier for this object. Only unique within a process. + uint64_t graph_id_; + // For testing. 
 friend GraphProfilerTestPeer;
 };
diff --git a/mediapipe/framework/profiler/graph_profiler_stub.h b/mediapipe/framework/profiler/graph_profiler_stub.h
index 12a024fe8..72d5d7275 100644
--- a/mediapipe/framework/profiler/graph_profiler_stub.h
+++ b/mediapipe/framework/profiler/graph_profiler_stub.h
@@ -93,6 +93,7 @@ class GraphProfilerStub {
       PopulateGraphConfig populate_config = PopulateGraphConfig::kNo) {
     return absl::OkStatus();
   }
+  inline absl::Status WriteProfile() { return absl::OkStatus(); }
   inline void Pause() {}
   inline void Resume() {}
   inline void Reset() {}
diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc
index 81ba90cda..e9badaa25 100644
--- a/mediapipe/framework/profiler/graph_profiler_test.cc
+++ b/mediapipe/framework/profiler/graph_profiler_test.cc
@@ -39,13 +39,15 @@ constexpr char kDummyTestCalculatorName[] = "DummyTestCalculator";
 CalculatorGraphConfig::Node CreateNodeConfig(
     const std::string& raw_node_config) {
   CalculatorGraphConfig::Node node_config;
-  QCHECK(proto2::TextFormat::ParseFromString(raw_node_config, &node_config));
+  QCHECK(google::protobuf::TextFormat::ParseFromString(raw_node_config,
+                                                        &node_config));
   return node_config;
 }
 
 CalculatorGraphConfig CreateGraphConfig(const std::string& raw_graph_config) {
   CalculatorGraphConfig graph_config;
-  QCHECK(proto2::TextFormat::ParseFromString(raw_graph_config, &graph_config));
+  QCHECK(google::protobuf::TextFormat::ParseFromString(raw_graph_config,
+                                                        &graph_config));
   return graph_config;
 }
 
@@ -442,6 +444,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
       "Cannot initialize .* multiple times.");
 }
 
+// Tests that graph identifiers are not reused, even after destruction.
+TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
+  auto raw_graph_config = R"(
+    profiler_config {
+      enable_profiler: true
+    }
+    input_stream: "input_stream"
+    node {
+      calculator: "DummyTestCalculator"
+      input_stream: "input_stream"
+    })";
+  const int n_iterations = 100;
+  absl::flat_hash_set<int> seen_ids;
+  for (int i = 0; i < n_iterations; ++i) {
+    std::shared_ptr<GraphProfiler> profiler =
+        std::make_shared<GraphProfiler>();
+    auto graph_config = CreateGraphConfig(raw_graph_config);
+    mediapipe::ValidatedGraphConfig validated_graph;
+    QCHECK_OK(validated_graph.Initialize(graph_config));
+    profiler->Initialize(validated_graph);
+
+    int id = profiler->GetGraphId();
+    ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
+    seen_ids.insert(id);
+  }
+}
 // Tests that Pause(), Resume(), and Reset() works.
 TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
   InitializeProfilerWithGraphConfig(R"(
@@ -1141,7 +1169,7 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) {
 TEST(GraphProfilerTest, ParallelReads) {
   // A graph that processes a certain number of packets before finishing.
   CalculatorGraphConfig config;
-  QCHECK(proto2::TextFormat::ParseFromString(R"(
+  QCHECK(google::protobuf::TextFormat::ParseFromString(R"(
    profiler_config {
      enable_profiler: true
    }
    node {
      calculator: "RangeCalculator"
      input_side_packet: "range_step"
      output_stream: "out"
      output_stream: "sum"
      output_stream: "mean"
    }
    node {
      calculator: "PassThroughCalculator"
      input_stream: "out"
      output_stream: "out_1"
    }
    output_stream: "OUT:0:the_integers"
  )",
-                                    &config));
+                                               &config));
 
  // Start running the graph on its own threads.
absl::Mutex out_1_mutex; @@ -1220,7 +1248,7 @@ std::set GetCalculatorNames(const CalculatorGraphConfig& config) { TEST(GraphProfilerTest, CalculatorProfileFilter) { CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true } @@ -1242,7 +1270,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) { } output_stream: "OUT:0:the_integers" )", - &config)); + &config)); std::set expected_names; expected_names = {"RangeCalculator", "PassThroughCalculator"}; @@ -1269,7 +1297,7 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) { TEST(GraphProfilerTest, CaptureProfilePopulateConfig) { CalculatorGraphConfig config; - QCHECK(proto2::TextFormat::ParseFromString(R"( + QCHECK(google::protobuf::TextFormat::ParseFromString(R"( profiler_config { enable_profiler: true trace_enabled: true @@ -1284,7 +1312,7 @@ TEST(GraphProfilerTest, CaptureProfilePopulateConfig) { input_stream: "input_stream" } )", - &config)); + &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); GraphProfile profile; diff --git a/mediapipe/framework/profiler/web_performance_profiling.h b/mediapipe/framework/profiler/web_performance_profiling.h new file mode 100644 index 000000000..47b76fe88 --- /dev/null +++ b/mediapipe/framework/profiler/web_performance_profiling.h @@ -0,0 +1,68 @@ +#ifndef MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_ +#define MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_ + +#if MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__ +#include + +#include "absl/strings/str_cat.h" + +// This records MediaPipe profiling events in the browser's performance trace. +// To use, build with: +// --define MEDIAPIPE_PROFILING=1 --define MEDIAPIPE_WEB_PROFILING=1 + +namespace mediapipe { + +class WepPerformanceTraceScope { + public: + explicit WepPerformanceTraceScope(TraceEvent::EventType event_type, + const char* event_type_str, + CalculatorContext* cc) + : event_type_str_(event_type_str), cc_(cc) { + const auto& calculator_name = cc->NodeName(); + std::string start_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_start"); + std::string timestamp_str = cc->InputTimestamp().DebugString(); + EM_ASM( + { + const startName = UTF8ToString($0); + const timestamp = UTF8ToString($1); + performance.mark(startName, {mp_timestamp : timestamp}); + }, + start_name.c_str(), timestamp_str.c_str()); + } + + ~WepPerformanceTraceScope() { + const auto& calculator_name = cc_->NodeName(); + std::string start_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_start"); + std::string end_name = + absl::StrCat(calculator_name, "::", event_type_str_, "_end"); + std::string measure_name = + absl::StrCat(calculator_name, "::", event_type_str_); + EM_ASM( + { + const startName = UTF8ToString($0); + const endName = UTF8ToString($1); + const measureName = UTF8ToString($2); + performance.mark(endName); + performance.measure(measureName, startName, endName); + }, + start_name.c_str(), end_name.c_str(), measure_name.c_str()); + } + + private: + const char* event_type_str_; + CalculatorContext* cc_; +}; + +} // namespace mediapipe + +#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context) \ + mediapipe::WepPerformanceTraceScope web_trace_scope( \ + mediapipe::TraceEvent::event_type, #event_type, calculator_context) + +#else +#define MEDIAPIPE_WEB_PERFORMANCE_SCOPE(event_type, calculator_context) +#endif // MEDIAPIPE_WEB_PROFILING_ENABLED && __EMSCRIPTEN__ + 
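The header above only defines the MEDIAPIPE_WEB_PERFORMANCE_SCOPE macro; a sketch of how a calculator could use it follows. The calculator class is hypothetical, and the assumption that TraceEvent::PROCESS is a valid event type is not confirmed by this patch; the build flags are the ones named in the header comment.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/profiler/web_performance_profiling.h"

namespace mediapipe {

// Hypothetical calculator; built with
// --define MEDIAPIPE_PROFILING=1 --define MEDIAPIPE_WEB_PROFILING=1.
class MyWebProfiledCalculator : public CalculatorBase {
 public:
  absl::Status Process(CalculatorContext* cc) final {
    // Records "<node name>::PROCESS_start" / "_end" marks and a
    // "<node name>::PROCESS" measure in the browser's performance timeline
    // when web profiling is enabled; expands to nothing otherwise.
    MEDIAPIPE_WEB_PERFORMANCE_SCOPE(PROCESS, cc);
    // ... actual per-packet work goes here ...
    return absl::OkStatus();
  }
};

}  // namespace mediapipe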
+#endif // MEDIAPIPE_FRAMEWORK_PROFILER_WEB_PERFORMANCE_PROFILING_H_ diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index afef4f383..854c10fd5 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() { // Note: state_mutex_ is held when this function is entered or // exited. void Scheduler::HandleIdle() { - if (handling_idle_) { + if (++handling_idle_ > 1) { // Someone is already inside this method. // Note: This can happen in the sections below where we unlock the mutex // and make more nodes runnable: the nodes can run and become idle again @@ -127,7 +127,6 @@ void Scheduler::HandleIdle() { VLOG(2) << "HandleIdle: already in progress"; return; } - handling_idle_ = true; while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) { // Remove active sources that are closed. @@ -165,11 +164,17 @@ void Scheduler::HandleIdle() { } } + // If HandleIdle has been called again, then continue scheduling. + if (handling_idle_ > 1) { + handling_idle_ = 1; + continue; + } + // Nothing left to do. break; } - handling_idle_ = false; + handling_idle_ = 0; } // Note: state_mutex_ is held when this function is entered or exited. diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index dd1572d99..b59467b9f 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -302,7 +302,7 @@ class Scheduler { // - We need it to be reentrant, which Mutex does not support. // - We want simultaneous calls to return immediately instead of waiting, // and Mutex's TryLock is not guaranteed to work. - bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false; + int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0; // Mutex for the scheduler state and related things. // Note: state_ is declared as atomic so that its getter methods don't need diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8771a8773..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,40 +13,36 @@ # limitations under the License. 
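Returning to the Scheduler::HandleIdle change above: the bool guard becomes a counter so that a call arriving while HandleIdle is already running requests one more pass instead of being dropped. A self-contained sketch of that pattern is below; it is illustrative only, not MediaPipe code, and it omits locking (the real scheduler guards the counter with its state mutex and releases that mutex around the work that can re-enter).

// Counter-based re-entrancy guard in the style of the HandleIdle change.
class IdleLoop {
 public:
  void HandleIdle() {
    if (++handling_idle_ > 1) {
      return;  // Already inside; the outer call will loop once more.
    }
    while (true) {
      DoOnePass();               // May call HandleIdle() again re-entrantly.
      if (handling_idle_ > 1) {  // A nested call arrived during the pass:
        handling_idle_ = 1;      // collapse it into one extra iteration.
        continue;
      }
      break;
    }
    handling_idle_ = 0;
  }

 private:
  void DoOnePass() {}
  int handling_idle_ = 0;
};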
# +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -54,7 +50,6 @@ mediapipe_cc_proto_library( name = "default_input_stream_handler_cc_proto", srcs = ["default_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":default_input_stream_handler_proto"], ) @@ -62,7 +57,6 @@ mediapipe_cc_proto_library( name = "fixed_size_input_stream_handler_cc_proto", srcs = ["fixed_size_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":fixed_size_input_stream_handler_proto"], ) @@ -70,7 +64,6 @@ mediapipe_cc_proto_library( name = "sync_set_input_stream_handler_cc_proto", srcs = ["sync_set_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":sync_set_input_stream_handler_proto"], ) @@ -78,14 +71,12 @@ mediapipe_cc_proto_library( name = "timestamp_align_input_stream_handler_cc_proto", srcs = ["timestamp_align_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":timestamp_align_input_stream_handler_proto"], ) cc_library( name = "barrier_input_stream_handler", srcs = ["barrier_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -96,10 +87,9 @@ cc_library( name = "default_input_stream_handler", srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -108,7 +98,6 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "@com_google_absl//absl/strings", @@ -119,11 +108,10 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", 
"//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -131,7 +119,6 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -142,7 +129,6 @@ cc_library( name = "in_order_output_stream_handler", srcs = ["in_order_output_stream_handler.cc"], hdrs = ["in_order_output_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -160,7 +146,6 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/port:logging", @@ -173,15 +158,14 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -192,12 +176,11 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -260,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -268,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -289,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -306,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index e721afb02..e5de7f0c9 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -230,6 +230,43 @@ TEST_F(ImmediateInputStreamHandlerTest, StreamDoneReady) { input_stream_handler_->ClearCurrentInputs(cc_); } +// This test checks that the state is ReadyForClose after all streams reach +// Timestamp::Max. +TEST_F(ImmediateInputStreamHandlerTest, ReadyForCloseAfterTimestampMax) { + Timestamp min_stream_timestamp; + std::list packets; + + // One packet arrives, ready for process. + packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(10))); + input_stream_handler_->AddPackets(name_to_id_["input_a"], packets); + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp(10), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // No packets arrive, not ready. + EXPECT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Unset(), cc_->InputTimestamp()); + + // Timestamp::Max arrives, ready for close. + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_a"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_b"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_c"], Timestamp::Max().NextAllowedInStream()); + + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Done(), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); +} + // This test checks that when any stream is done, the state is ready to close. 
TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { Timestamp min_stream_timestamp; diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index d0f018e1a..7cbde28bf 100644 --- a/mediapipe/framework/subgraph.cc +++ b/mediapipe/framework/subgraph.cc @@ -92,7 +92,7 @@ bool GraphRegistry::IsRegistered(const std::string& ns, } absl::StatusOr GraphRegistry::CreateByName( - const std::string& ns, const std::string& type_name, + absl::string_view ns, absl::string_view type_name, SubgraphContext* context) const { absl::StatusOr> maker = local_factories_.IsRegistered(ns, type_name) diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index b3e7d958b..5b1d9646a 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -20,6 +20,7 @@ #include "absl/base/macros.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/registration.h" @@ -187,7 +188,7 @@ class GraphRegistry { // Returns the specified graph config. absl::StatusOr CreateByName( - const std::string& ns, const std::string& type_name, + absl::string_view ns, absl::string_view type_name, SubgraphContext* context = nullptr) const; static GraphRegistry global_graph_registry; diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index e54fb2177..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -90,7 +90,7 @@ mediapipe_proto_library( name = "packet_generator_wrapper_calculator_proto", srcs = ["packet_generator_wrapper_calculator.proto"], def_py_proto = False, - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:packet_generator_proto", @@ -120,13 +120,13 @@ cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], hdrs = ["fill_packet_set.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ + ":status_util", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/memory", ], ) @@ -162,7 +162,6 @@ cc_library( cc_test( name = "executor_util_test", srcs = ["executor_util_test.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":executor_util", "//mediapipe/framework/port:gtest_main", @@ -173,7 +172,7 @@ cc_test( cc_library( name = "options_map", hdrs = ["options_map.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe:__subpackages__"], deps = [ ":type_util", "//mediapipe/framework:calculator_cc_proto", @@ -193,7 +192,7 @@ cc_library( name = "options_field_util", srcs = ["options_field_util.cc"], hdrs = ["options_field_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":field_data_cc_proto", ":name_util", @@ -216,7 +215,7 @@ cc_library( name = "options_syntax_util", srcs = ["options_syntax_util.cc"], hdrs = ["options_syntax_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":name_util", ":options_field_util", @@ -235,8 +234,9 @@ cc_library( name = 
"options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + ":name_util", ":options_field_util", ":options_map", ":options_registry", @@ -254,7 +254,6 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -299,12 +298,13 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:basic_types_registration", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -312,8 +312,8 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -322,7 +322,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":packet_generator_wrapper_calculator_cc_proto", "//mediapipe/framework:calculator_base", @@ -346,6 +346,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", "@com_google_absl//absl/strings", ], ) @@ -421,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -484,14 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,7 +506,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:integral_types", @@ -514,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -660,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -738,9 +739,7 @@ cc_test( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework:status_handler", @@ -790,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -806,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -815,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -842,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -851,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -894,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -905,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -923,7 +922,6 @@ cc_test( "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index 27153f3f7..31c233812 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -27,6 +27,9 @@ message TemplateExpression { // The FieldDescriptor::Type of the modified field. optional mediapipe.FieldDescriptorProto.Type field_type = 5; + // The FieldDescriptor::Type of each map key in the path. + repeated mediapipe.FieldDescriptorProto.Type key_type = 6; + // Alternative value for the modified field, in protobuf binary format. 
optional string field_value = 7; } diff --git a/mediapipe/framework/tool/mediapipe_graph.bzl b/mediapipe/framework/tool/mediapipe_graph.bzl index 45d98b1eb..ef5182a53 100644 --- a/mediapipe/framework/tool/mediapipe_graph.bzl +++ b/mediapipe/framework/tool/mediapipe_graph.bzl @@ -67,7 +67,8 @@ def data_as_c_string( name, srcs, outs = None, - testonly = None): + testonly = None, + compatible_with = None): """Encodes the data from a file as a C string literal. This produces a text file containing the quoted C string literal. It can be @@ -79,6 +80,7 @@ def data_as_c_string( outs: A list containing a single item, the name of the output text file. Defaults to the rule name. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. """ if len(srcs) != 1: fail("srcs must be a single-element list") @@ -92,6 +94,7 @@ def data_as_c_string( cmd = "$(location %s) \"$<\" > \"$@\"" % encode_as_c_string, tools = [encode_as_c_string], testonly = testonly, + compatible_with = compatible_with, ) def mediapipe_simple_subgraph( @@ -208,6 +211,7 @@ def mediapipe_options_library( deps = [], visibility = None, testonly = None, + compatible_with = None, **kwargs): """Registers options protobuf metadata for defining options packets. @@ -217,6 +221,7 @@ def mediapipe_options_library( deps: any additional protobuf dependencies. visibility: The list of packages the subgraph should be visible to. testonly: pass 1 if the graph is to be used only for tests. + compatible_with: a list of environments the rule is compatible with. **kwargs: Remaining keyword args, forwarded to cc_library. """ @@ -224,16 +229,19 @@ def mediapipe_options_library( name = proto_lib + "_transitive", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) direct_descriptor_set( name = proto_lib + "_direct", deps = [proto_lib], testonly = testonly, + compatible_with = compatible_with, ) data_as_c_string( name = name + "_inc", srcs = [proto_lib + "_transitive-transitive-descriptor-set.proto.bin"], outs = [proto_lib + "_descriptors.inc"], + compatible_with = compatible_with, ) native.genrule( name = name + "_type_name", @@ -245,6 +253,7 @@ def mediapipe_options_library( tools = ["//mediapipe/framework/tool:message_type_util"], visibility = visibility, testonly = testonly, + compatible_with = compatible_with, ) expand_template( name = name + "_cc", @@ -256,6 +265,7 @@ def mediapipe_options_library( "{{DESCRIPTOR_INC_FILE_PATH}}": native.package_name() + "/" + proto_lib + "_descriptors.inc", }, testonly = testonly, + compatible_with = compatible_with, ) native.cc_library( name = proto_lib.replace("_proto", "_options_registry"), @@ -274,6 +284,7 @@ def mediapipe_options_library( visibility = visibility, testonly = testonly, features = ["-no_undefined"], + compatible_with = compatible_with, **kwargs ) mediapipe_reexport_library( diff --git a/mediapipe/framework/tool/options_lib_template.cc b/mediapipe/framework/tool/options_lib_template.cc index 21a5db10f..4861132a2 100644 --- a/mediapipe/framework/tool/options_lib_template.cc +++ b/mediapipe/framework/tool/options_lib_template.cc @@ -28,7 +28,7 @@ constexpr char kDescriptorContents[] = mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) { mediapipe::FieldData result; *result.mutable_message_value()->mutable_type_url() = - "proto2.FileDescriptorSet"; + "google::protobuf.FileDescriptorSet"; *result.mutable_message_value()->mutable_value() = pb; // Force linking of the generated options protobuf. 
diff --git a/mediapipe/framework/tool/options_registry.cc b/mediapipe/framework/tool/options_registry.cc index f6858be0a..07cc65a95 100644 --- a/mediapipe/framework/tool/options_registry.cc +++ b/mediapipe/framework/tool/options_registry.cc @@ -66,26 +66,28 @@ std::string GetFieldString(const FieldData& message_data, void RegisterDescriptorProtos( absl::flat_hash_map& result) { std::vector descriptors = { - {"proto2.FileDescriptorSet", + {"google::protobuf.FileDescriptorSet", { - {"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"}, + {"file", 1, FieldType::TYPE_MESSAGE, + "google::protobuf.FileDescriptorProto"}, }}, - {"proto2.FileDescriptorProto", + {"google::protobuf.FileDescriptorProto", { {"package", 2, FieldType::TYPE_STRING, ""}, {"message_type", 4, FieldType::TYPE_MESSAGE, - "proto2.DescriptorProto"}, + "google::protobuf.DescriptorProto"}, }}, - {"proto2.DescriptorProto", + {"google::protobuf.DescriptorProto", { {"name", 1, FieldType::TYPE_STRING, ""}, - {"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"}, + {"field", 2, FieldType::TYPE_MESSAGE, + "google::protobuf.FieldDescriptorProto"}, {"extension", 6, FieldType::TYPE_MESSAGE, - "proto2.FieldDescriptorProto"}, + "google::protobuf.FieldDescriptorProto"}, {"nested_type", 3, FieldType::TYPE_MESSAGE, - "proto2.DescriptorProto"}, + "google::protobuf.DescriptorProto"}, }}, - {"proto2.FieldDescriptorProto", + {"google::protobuf.FieldDescriptorProto", { {"name", 1, FieldType::TYPE_STRING, ""}, {"number", 3, FieldType::TYPE_INT32, ""}, @@ -140,7 +142,7 @@ void OptionsRegistry::Register(const FieldData& message_type, const Descriptor* OptionsRegistry::GetProtobufDescriptor( const std::string& type_name) { - if (descriptors().count("proto2.DescriptorProto") == 0) { + if (descriptors().count("google::protobuf.DescriptorProto") == 0) { RegisterDescriptorProtos(descriptors()); } absl::ReaderMutexLock lock(&mutex()); diff --git a/mediapipe/framework/tool/options_registry.h b/mediapipe/framework/tool/options_registry.h index b843b113a..3b2d2be89 100644 --- a/mediapipe/framework/tool/options_registry.h +++ b/mediapipe/framework/tool/options_registry.h @@ -28,7 +28,7 @@ class OptionsRegistry { // Finds the descriptor for a protobuf. static const Descriptor* GetProtobufDescriptor(const std::string& type_name); - // Returns all known proto2 extensions to a type. + // Returns all known google::protobuf extensions to a type. static void FindAllExtensions(absl::string_view extendee, std::vector* result); diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 4628815ea..a810ce129 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/field_data.pb.h" #include "mediapipe/framework/type_map.h" @@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. 
-absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, - CodedInputStream* in, CodedOutputStream* out, +absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, + CodedOutputStream* out, std::vector* field_values) { uint32 tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); + WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (field_number == field_id) { if (!IsLengthDelimited(wire_type) && IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) { @@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) { CodedInputStream in(&ais); StringOutputStream sos(&message_); CodedOutputStream out(&sos); - WireFormatLite::WireType wire_type = - WireFormatLite::WireTypeForFieldType(field_type_); - return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_); + return GetFieldValues(field_id_, &in, &out, &field_values_); } void FieldAccess::GetMessage(std::string* result) { @@ -149,18 +149,56 @@ std::vector* FieldAccess::mutable_field_values() { return &field_values_; } +namespace { +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; + +// Returns the FieldAccess and index for a field-id or a map-id. +// Returns access to the field-id if the field index is found, +// to the map-id if the map entry is found, and to the field-id otherwise. +absl::StatusOr> AccessField( + const ProtoPathEntry& entry, FieldType field_type, + const FieldValue& message) { + FieldAccess result(entry.field_id, field_type); + if (entry.field_id >= 0) { + MP_RETURN_IF_ERROR(result.SetMessage(message)); + if (entry.index < result.mutable_field_values()->size()) { + return std::pair(result, entry.index); + } + } + if (entry.map_id >= 0) { + FieldAccess access(entry.map_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(message)); + auto& field_values = *access.mutable_field_values(); + for (int index = 0; index < field_values.size(); ++index) { + FieldAccess key(entry.key_id, entry.key_type); + MP_RETURN_IF_ERROR(key.SetMessage(field_values[index])); + if (key.mutable_field_values()->at(0) == entry.key_value) { + return std::pair(std::move(access), index); + } + } + } + if (entry.field_id >= 0) { + return std::pair(result, entry.index); + } + return absl::InvalidArgumentError(absl::StrCat( + "ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ", + entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type)); +} + +} // namespace + // Replaces a range of field values for one field nested within a protobuf. absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(*message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? 
WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, @@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector* field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldRange(v[index], proto_path, length, field_type, field_values)); } else { + if (length == -1) { + length = v.size() - index; + } RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); field_values->insert(field_values->begin(), v.begin() + index, @@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, ProtoPath proto_path, FieldType field_type, int* field_count) { - int field_id, index; - std::tie(field_id, index) = proto_path.back(); - proto_path.pop_back(); - std::vector parent; - if (proto_path.empty()) { - parent.push_back(std::string(message)); + ProtoPathEntry entry = proto_path.front(); + proto_path.erase(proto_path.begin()); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); + if (!proto_path.empty()) { + RET_CHECK_NO_LOG(index >= 0 && index < v.size()); + MP_RETURN_IF_ERROR( + GetFieldCount(v[index], proto_path, field_type, field_count)); } else { - MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( - message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + *field_count = v.size(); } - FieldAccess access(field_id, field_type); - MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); - *field_count = access.mutable_field_values()->size(); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index 7d3a263f3..15e321eeb 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -34,15 +34,36 @@ class ProtoUtilLite { // Defines field types and tag formats. using WireFormatLite = proto_ns::internal::WireFormatLite; - // Defines a sequence of nested field-number field-index pairs. - using ProtoPath = std::vector>; - // The serialized value for a protobuf field. using FieldValue = std::string; // The serialized data type for a protobuf field. using FieldType = WireFormatLite::FieldType; + // A field-id and index, or a map-id and key, or both. 
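As an aside on the hunks above (not part of this patch): GetFieldRange and GetFieldCount now walk a ProtoPath of ProtoPathEntry values, the struct defined just below. A minimal caller sketch follows; the helper name CountTopLevelField and the choice of field number 1 (assumed to be CalculatorGraphConfig.node) are illustrative assumptions.

#include <string>

#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/proto_util_lite.h"

// Counts how many values field 1 holds in a serialized message, using the
// refactored ProtoPathEntry-based GetFieldCount.
absl::Status CountTopLevelField(const std::string& serialized_config,
                                int* count) {
  using mediapipe::tool::ProtoUtilLite;
  ProtoUtilLite::ProtoPath path = {{/*id=*/1, /*index=*/0}};
  return ProtoUtilLite::GetFieldCount(
      serialized_config, path,
      ProtoUtilLite::WireFormatLite::TYPE_MESSAGE, count);
}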
+ struct ProtoPathEntry { + ProtoPathEntry(int id, int index) : field_id(id), index(index) {} + ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value) + : map_id(id), + key_id(key_id), + key_type(key_type), + key_value(std::move(key_value)) {} + bool operator==(const ProtoPathEntry& o) const { + return field_id == o.field_id && index == o.index && map_id == o.map_id && + key_id == o.key_id && key_type == o.key_type && + key_value == o.key_value; + } + int field_id = -1; + int index = -1; + int map_id = -1; + int key_id = -1; + FieldType key_type = FieldType::MAX_FIELD_TYPE; + FieldValue key_value; + }; + + // Defines a sequence of nested field-number field-index pairs. + using ProtoPath = std::vector; + class FieldAccess { public: // Provides access to a certain protobuf field. @@ -57,9 +78,11 @@ class ProtoUtilLite { // Returns the serialized values of the protobuf field. std::vector* mutable_field_values(); + uint32 field_id() const { return field_id_; } + private: - const uint32 field_id_; - const FieldType field_type_; + uint32 field_id_; + FieldType field_type_; std::string message_; std::vector field_values_; }; diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index 4a181b43f..f8abf4925 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc @@ -87,7 +87,8 @@ void AddVectorSink(const std::string& stream_name, // node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::VECTOR_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. + char address[19]; int written = snprintf(address, sizeof(address), "%p", dumped_data); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); @@ -112,7 +113,8 @@ void AddPostStreamPacketSink(const std::string& stream_name, node->mutable_options()->MutableExtension( CallbackPacketCalculatorOptions::ext); options->set_type(CallbackPacketCalculatorOptions::POST_STREAM_PACKET); - char address[17]; + // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. + char address[19]; int written = snprintf(address, sizeof(address), "%p", post_stream_packet); CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); diff --git a/mediapipe/framework/tool/sink_test.cc b/mediapipe/framework/tool/sink_test.cc index 2b5f94f9f..c5316af4d 100644 --- a/mediapipe/framework/tool/sink_test.cc +++ b/mediapipe/framework/tool/sink_test.cc @@ -171,6 +171,7 @@ class TimestampBoundTestCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TimestampBoundTestCalculator); +#if 0 // test is flaky, try it with --runs_per_test=200 TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) { std::string config_str = R"( node { @@ -203,6 +204,7 @@ TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) { EXPECT_THAT(sums, testing::ElementsAre(10, 20)); } +#endif } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/tool/switch/BUILD b/mediapipe/framework/tool/switch/BUILD new file mode 100644 index 000000000..e7a3ba741 --- /dev/null +++ b/mediapipe/framework/tool/switch/BUILD @@ -0,0 +1,60 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +cc_library( + name = "packet_processor", + hdrs = ["packet_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:packet", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "graph_processor", + srcs = ["graph_processor.cc"], + hdrs = ["graph_processor.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":packet_processor", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:input_stream_shard", + "//mediapipe/framework:output_stream_shard", + "//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) diff --git a/mediapipe/framework/tool/switch/graph_processor.cc b/mediapipe/framework/tool/switch/graph_processor.cc new file mode 100644 index 000000000..f35730761 --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.cc @@ -0,0 +1,110 @@ +#include "mediapipe/framework/tool/switch/graph_processor.h" + +#include "absl/synchronization/mutex.h" + +namespace mediapipe { + +// TODO: add support for input and output side packets. 
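Before the method definitions that follow, a hedged usage sketch of the new GraphProcessor; the sketch is not part of the change, RunGraphOnce, config, sink, and input are invented for illustration, and error handling is abbreviated.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/switch/graph_processor.h"

// Configures a GraphProcessor, attaches a consumer for its outputs, feeds a
// single packet into the first graph input stream, drains, and shuts down.
absl::Status RunGraphOnce(const mediapipe::CalculatorGraphConfig& config,
                          mediapipe::PacketConsumer* sink,
                          mediapipe::Packet input) {
  mediapipe::GraphProcessor processor;
  MP_RETURN_IF_ERROR(processor.Initialize(config));
  processor.SetConsumer(sink);
  MP_RETURN_IF_ERROR(processor.Start());
  // This sketch assumes the graph declares exactly one input stream.
  MP_RETURN_IF_ERROR(
      processor.AddPacket(processor.InputTags()->BeginId(), input));
  MP_RETURN_IF_ERROR(processor.WaitUntilIdle());
  return processor.Shutdown();
}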
+absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) { + graph_config_ = graph_config; + + ASSIGN_OR_RETURN(graph_input_map_, + tool::TagMap::Create(graph_config_.input_stream())); + ASSIGN_OR_RETURN(graph_output_map_, + tool::TagMap::Create(graph_config_.output_stream())); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) { + absl::MutexLock lock(&graph_mutex_); + const std::string& stream_name = graph_input_map_->Names().at(id.value()); + return graph_->AddPacketToInputStream(stream_name, packet); +} + +std::shared_ptr GraphProcessor::InputTags() { + return graph_input_map_; +} + +absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) { + MP_RETURN_IF_ERROR(WaitUntilInitialized()); + auto it = consumer_ids_.find(id); + if (it == consumer_ids_.end()) { + return absl::NotFoundError( + absl::StrCat("Consumer stream not found: ", id.value())); + } + return consumer_->AddPacket(it->second, packet); +} + +void GraphProcessor::SetConsumer(PacketConsumer* consumer) { + absl::MutexLock lock(&graph_mutex_); + consumer_ = consumer; + auto input_map = consumer_->InputTags(); + for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) { + auto tag_index = input_map->TagAndIndexFromId(id); + auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second); + consumer_ids_[stream_id] = id; + } +} + +absl::Status GraphProcessor::ObserveGraph() { + for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId(); + ++id) { + std::string stream_name = graph_output_map_->Names().at(id.value()); + MP_RETURN_IF_ERROR(graph_->ObserveOutputStream( + stream_name, + [this, id](const Packet& packet) { return SendPacket(id, packet); }, + true)); + } + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilInitialized() { + absl::MutexLock lock(&graph_mutex_); + auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) { + return graph_ != nullptr && consumer_ != nullptr; + }; + graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized), + absl::Seconds(4)); + RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out."; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Start() { + absl::MutexLock lock(&graph_mutex_); + graph_ = std::make_unique(); + + // The graph is validated here with its specified inputs and output. 
+ MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_)); + MP_RETURN_IF_ERROR(ObserveGraph()); + MP_RETURN_IF_ERROR(graph_->StartRun({})); + return absl::OkStatus(); +} + +absl::Status GraphProcessor::Shutdown() { + absl::MutexLock lock(&graph_mutex_); + if (!graph_) { + return absl::OkStatus(); + } + MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph_->WaitUntilDone()); + graph_ = nullptr; + return absl::OkStatus(); +} + +absl::Status GraphProcessor::WaitUntilIdle() { + absl::MutexLock lock(&graph_mutex_); + return graph_->WaitUntilIdle(); +} + +// TODO +absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) { + return absl::OkStatus(); +} +// TODO +std::shared_ptr GraphProcessor::SideInputTags() { + return nullptr; +} +// TODO +void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {} + +} // namespace mediapipe diff --git a/mediapipe/framework/tool/switch/graph_processor.h b/mediapipe/framework/tool/switch/graph_processor.h new file mode 100644 index 000000000..e2220b5dc --- /dev/null +++ b/mediapipe/framework/tool/switch/graph_processor.h @@ -0,0 +1,59 @@ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/switch/packet_processor.h" + +namespace mediapipe { + +// Processes MediaPipe Packets using a MediaPipe CalculatorGraph. +class GraphProcessor : public PacketProcessor { + public: + GraphProcessor() = default; + + // Configures this GraphProcessor to create and run a CalculatorGraph. + absl::Status Initialize(CalculatorGraphConfig graph_config); + + public: + // The PacketProcessor interface. + absl::Status AddPacket(CollectionItemId id, Packet packet) override; + std::shared_ptr InputTags() override; + absl::Status SetSidePacket(CollectionItemId id, Packet packet) override; + std::shared_ptr SideInputTags() override; + void SetConsumer(PacketConsumer* consumer) override; + void SetSideConsumer(SidePacketConsumer* consumer) override; + absl::Status Start() override; + absl::Status Shutdown() override; + absl::Status WaitUntilIdle() override; + + private: + // Sends a tagged output packet. + absl::Status SendPacket(CollectionItemId id, Packet packet); + + // Observes output packets from the calculator graph. + absl::Status ObserveGraph() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_); + + // Blocks until this GraphProcessor is initialized. + absl::Status WaitUntilInitialized(); + + private: + CalculatorGraphConfig graph_config_; + std::shared_ptr graph_input_map_; + std::shared_ptr graph_output_map_; + std::map consumer_ids_; + + PacketConsumer* consumer_ = nullptr; + std::map side_packets_; + std::unique_ptr graph_ ABSL_GUARDED_BY(graph_mutex_) = + nullptr; + absl::Mutex graph_mutex_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_GRAPH_PROCESSOR_H_ diff --git a/mediapipe/framework/tool/switch/packet_processor.h b/mediapipe/framework/tool/switch/packet_processor.h new file mode 100644 index 000000000..d97883c53 --- /dev/null +++ b/mediapipe/framework/tool/switch/packet_processor.h @@ -0,0 +1,88 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ + +#include + +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// PacketConsumer accepts several tagged streams of packets. +class PacketConsumer { + public: + virtual ~PacketConsumer() = default; + + // Accepts a tagged input packet. + virtual absl::Status AddPacket(CollectionItemId id, Packet packet) = 0; + + // Returns the id for each input tag. + virtual std::shared_ptr InputTags() = 0; +}; + +// PacketProducer delivers several tagged streams of packets. +class PacketProducer { + public: + virtual ~PacketProducer() = default; + + // Connects a consumer to receive packets from this producer. + virtual void SetConsumer(PacketConsumer* consumer) = 0; +}; + +// SidePacketConsumer accepts several tagged constant packets. +class SidePacketConsumer { + public: + virtual ~SidePacketConsumer() = default; + + // Accepts a tagged input side-packet. + virtual absl::Status SetSidePacket(CollectionItemId id, Packet packet) = 0; + + // Returns the id for each input side-packet tag. + virtual std::shared_ptr SideInputTags() = 0; +}; + +// SidePacketProducer delivers several tagged constant packets. +class SidePacketProducer { + public: + virtual ~SidePacketProducer() = default; + + // Connects a consumer to receive packets from this producer. + virtual void SetSideConsumer(SidePacketConsumer* consumer) = 0; +}; + +// PacketProcessor consumes and produces packet streams and constant packets. +class PacketProcessor : public PacketConsumer, + public PacketProducer, + public SidePacketConsumer, + public SidePacketProducer { + public: + virtual ~PacketProcessor() = default; + + // Activate this PacketProcessor. + virtual absl::Status Start() = 0; + + // Block until this PacketProcessor has no remaining work to do. + virtual absl::Status WaitUntilIdle() = 0; + + // Deactivate this PacketProcessor. + virtual absl::Status Shutdown() = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_PACKET_PROCESSOR_H_ diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 034e1a026..a91ea5adc 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -22,6 +22,7 @@ #include #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite; using FieldValue = ProtoUtilLite::FieldValue; using FieldType = ProtoUtilLite::FieldType; using ProtoPath = ProtoUtilLite::ProtoPath; +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; namespace { @@ -84,26 +86,87 @@ std::unique_ptr CloneMessage(const MessageLite& message) { return result; } -// Returns the (tag, index) pairs in a field path. -// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]".
-absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - absl::Status status; - std::vector ids = absl::StrSplit(path, '/'); - for (const std::string& id : ids) { - if (id.length() > 0) { - std::pair id_pair = - absl::StrSplit(id, absl::ByAnyChar("[]")); - int tag = 0; - int index = 0; - bool ok = absl::SimpleAtoi(id_pair.first, &tag) && - absl::SimpleAtoi(id_pair.second, &index); - if (!ok) { - status.Update(absl::InvalidArgumentError(path)); - } - result->push_back(std::make_pair(tag, index)); +// Parses one ProtoPathEntry. +// The parsed entry is appended to `result` and removed from `path`. +// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes +// to serialize the key text to protobuf wire format. +absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) { + bool ok = true; + int sb = path.find('['); + int eb = path.find(']'); + int field_id = -1; + ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id); + auto selector = path.substr(sb + 1, eb - 1 - sb); + if (absl::StartsWith(selector, "@")) { + int eq = selector.find('='); + int key_id = -1; + ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id); + auto key_text = selector.substr(eq + 1); + FieldType key_type = FieldType::TYPE_STRING; + result->push_back({field_id, key_id, key_type, std::string(key_text)}); + } else { + int index = 0; + ok &= absl::SimpleAtoi(selector, &index); + result->push_back({field_id, index}); + } + int end = path.find('/', eb); + if (end == std::string::npos) { + path = ""; + } else { + path = path.substr(end + 1); + } + return ok ? absl::OkStatus() + : absl::InvalidArgumentError( + absl::StrCat("Failed to parse ProtoPath entry: ", path)); +} + +// Specifies the FieldTypes for protobuf map keys in a ProtoPath. +// Each ProtoPathEntry::key_value is converted from text to the protobuf +// wire format for its key type. +absl::Status SetMapKeyTypes(const std::vector& key_types, + ProtoPath* result) { + int i = 0; + for (ProtoPathEntry& entry : *result) { + if (entry.map_id >= 0) { + FieldType key_type = key_types[i++]; + std::vector key_value; + MP_RETURN_IF_ERROR( + ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value)); + entry.key_type = key_type; + entry.key_value = key_value.front(); } } - return status; + return absl::OkStatus(); +} + +// Returns the (tag, index) pairs in a field path. +// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]", +// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]". +absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + result->clear(); + absl::string_view rest = path; + if (absl::StartsWith(rest, "/")) { + rest = rest.substr(1); + } + while (!rest.empty()) { + MP_RETURN_IF_ERROR(ParseEntry(rest, result)); + } + return absl::OkStatus(); +} + +// Parse the TemplateExpression.path field into a ProtoPath struct. +absl::Status ParseProtoPath(const TemplateExpression& rule, + std::string base_path, ProtoPath* result) { + ProtoPath base_entries; + MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries)); + MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result)); + std::vector key_types; + for (int type : rule.key_type()) { + key_types.push_back(static_cast(type)); + } + MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result)); + result->erase(result->begin(), result->begin() + base_entries.size()); + return absl::OkStatus(); } // Returns true if one proto path is prefix by another. 
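To make the extended selector syntax concrete, a hypothetical fragment (ProtoPathSplit is file-local here, DemoPathParse is invented, and the field numbers mirror the comment above rather than any real message):

absl::Status DemoPathParse() {
  ProtoPath path;
  MP_RETURN_IF_ERROR(ProtoPathSplit("/1[1]/2[@1=INPUT_FRAMES]", &path));
  // path[0]: field_id = 1, index = 1
  // path[1]: map_id = 2, key_id = 1, key_type = TYPE_STRING (the default),
  //          key_value = "INPUT_FRAMES" as text, until SetMapKeyTypes()
  //          re-encodes it into the protobuf wire format of its key type.
  return absl::OkStatus();
}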
@@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) { return absl::StartsWith(path, prefix); } -// Returns the part of one proto path after a prefix proto path. -std::string ProtoPathRelative(const std::string& field_path, - const std::string& base_path) { - CHECK(ProtoPathStartsWith(field_path, base_path)); - return field_path.substr(base_path.length()); -} - // Returns the target ProtoUtilLite::FieldType of a rule. FieldType GetFieldType(const TemplateExpression& rule) { return static_cast(rule.field_type()); @@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) { // Returns the count of field values at a ProtoPath. int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { - int field_id, index; - std::tie(field_id, index) = field_path.back(); - field_path.pop_back(); - std::vector parent; - if (field_path.empty()) { - parent.push_back(base); - } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange( - base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); - } - ProtoUtilLite::FieldAccess access(field_id, field_type); - MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0])); - return access.mutable_field_values()->size(); + int result = 0; + CHECK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + return result; } } // namespace @@ -229,9 +276,7 @@ class TemplateExpanderImpl { return absl::OkStatus(); } ProtoPath field_path; - absl::Status status = - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); - if (!status.ok()) return status; + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); return ProtoUtilLite::GetFieldRange(output, field_path, 1, GetFieldType(rule), base); } @@ -242,12 +287,13 @@ class TemplateExpanderImpl { const std::vector& field_values, FieldValue* output) { if (!rule.has_path()) { - *output = field_values[0]; + if (!field_values.empty()) { + *output = field_values[0]; + } return absl::OkStatus(); } ProtoPath field_path; - RET_CHECK_OK( - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path)); + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); int field_count = 1; if (rule.has_field_value()) { // For a non-repeated field, only one value can be specified. @@ -257,7 +303,7 @@ class TemplateExpanderImpl { "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. 
- field_path[field_path.size() - 1].second = 0; + field_path[field_path.size() - 1].index = 0; field_count = 0; } return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count, diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 1d81e7a78..6c7237f8e 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" @@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message; using mediapipe::proto_ns::OneofDescriptor; using mediapipe::proto_ns::Reflection; using mediapipe::proto_ns::TextFormat; +using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath; +using FieldType = mediapipe::tool::ProtoUtilLite::FieldType; +using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue; namespace mediapipe { @@ -326,7 +330,7 @@ class TemplateParser::Parser::ParserImpl { return suc && LookingAtType(io::Tokenizer::TYPE_END); } - void ReportError(int line, int col, const std::string& message) { + void ReportError(int line, int col, absl::string_view message) { had_errors_ = true; if (error_collector_ == NULL) { if (line >= 0) { @@ -338,11 +342,11 @@ class TemplateParser::Parser::ParserImpl { << root_message_type_->full_name() << ": " << message; } } else { - error_collector_->AddError(line, col, message); + error_collector_->AddError(line, col, std::string(message)); } } - void ReportWarning(int line, int col, const std::string& message) { + void ReportWarning(int line, int col, absl::string_view message) { if (error_collector_ == NULL) { if (line >= 0) { LOG(WARNING) << "Warning parsing text-format " @@ -353,21 +357,21 @@ class TemplateParser::Parser::ParserImpl { << root_message_type_->full_name() << ": " << message; } } else { - error_collector_->AddWarning(line, col, message); + error_collector_->AddWarning(line, col, std::string(message)); } } protected: // Reports an error with the given message with information indicating // the position (as derived from the current token). - void ReportError(const std::string& message) { + void ReportError(absl::string_view message) { ReportError(tokenizer_.current().line, tokenizer_.current().column, message); } // Reports a warning with the given message with information indicating // the position (as derived from the current token). - void ReportWarning(const std::string& message) { + void ReportWarning(absl::string_view message) { ReportWarning(tokenizer_.current().line, tokenizer_.current().column, message); } @@ -375,7 +379,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes the specified message with the given starting delimiter. // This method checks to see that the end delimiter at the conclusion of // the consumption matches the starting delimiter passed in here. - bool ConsumeMessage(Message* message, const std::string delimiter) { + bool ConsumeMessage(Message* message, absl::string_view delimiter) { while (!LookingAt(">") && !LookingAt("}")) { if (LookingAt("%")) { DO(ConsumeFieldTemplate(message)); @@ -403,7 +407,7 @@ class TemplateParser::Parser::ParserImpl { #ifndef PROTO2_OPENSOURCE // Consumes a string value and parses it as a packed repeated field into // the given field of the given message. 
- bool ConsumePackedFieldAsString(const std::string& field_name, + bool ConsumePackedFieldAsString(absl::string_view field_name, const FieldDescriptor* field, Message* message) { std::string packed; @@ -427,8 +431,8 @@ class TemplateParser::Parser::ParserImpl { io::ArrayInputStream array_input(tagged.data(), tagged.size()); io::CodedInputStream coded_input(&array_input); if (!message->MergePartialFromCodedStream(&coded_input)) { - ReportError("Could not parse packed field \"" + field_name + - "\" as wire-encoded string."); + ReportError(absl::StrCat("Could not parse packed field \"", field_name, + "\" as wire-encoded string.")); return false; } @@ -1215,12 +1219,12 @@ class TemplateParser::Parser::ParserImpl { // Consumes a token and confirms that it matches that specified in the // value parameter. Returns false if the token found does not match that // which was specified. - bool Consume(const std::string& value) { + bool Consume(absl::string_view value) { const std::string& current_value = tokenizer_.current().text; if (current_value != value) { - ReportError("Expected \"" + value + "\", found \"" + current_value + - "\"."); + ReportError(absl::StrCat("Expected \"", value, "\", found \"", + current_value, "\".")); return false; } @@ -1357,32 +1361,138 @@ absl::Status ProtoPathSplit(const std::string& path, if (!ok) { status.Update(absl::InvalidArgumentError(path)); } - result->push_back(std::make_pair(tag, index)); + result->push_back({tag, index}); } } return status; } +// Returns a message serialized deterministically. +bool DeterministicallySerialize(const Message& proto, std::string* result) { + proto_ns::io::StringOutputStream stream(result); + proto_ns::io::CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + return proto.SerializeToCodedStream(&output); +} + // Serialize one field of a message. void SerializeField(const Message* message, const FieldDescriptor* field, std::vector* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); *result = *access.mutable_field_values(); } +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Returns all FieldDescriptors including extensions. 
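For readers cross-referencing ProtoPathJoin above with template_expander.cc: a small illustrative fragment (field numbers are invented; ProtoPath and FieldType are the aliases introduced at the top of this file):

ProtoPath path;
path.push_back({/*id=*/1, /*index=*/1});
path.push_back({/*id=*/2, /*key_id=*/1, FieldType::TYPE_STRING,
                /*key_value=*/"INPUT_FRAMES"});
// ProtoPathJoin(path) renders "/1[1]/2[@1=INPUT_FRAMES]", the same syntax
// that ProtoPathSplit in template_expander.cc parses back into entries.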
+std::vector GetFields(const Message* src) { + std::vector result; + src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(), + &result); + for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) { + result.push_back(src->GetDescriptor()->field(i)); + } + return result; +} + +// Orders map entries in dst to match src. +void OrderMapEntries(const Message* src, Message* dst, + std::set* seen = nullptr) { + std::unique_ptr> seen_owner; + if (!seen) { + seen_owner = std::make_unique>(); + seen = seen_owner.get(); + } + if (seen->count(src) > 0) { + return; + } else { + seen->insert(src); + } + for (auto field : GetFields(src)) { + if (field->is_map()) { + dst->GetReflection()->ClearField(dst, field); + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + const Message& entry = + src->GetReflection()->GetRepeatedMessage(*src, field, j); + dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry); + } + } + if (field->type() == FieldDescriptor::TYPE_MESSAGE) { + if (field->is_repeated()) { + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + OrderMapEntries( + &src->GetReflection()->GetRepeatedMessage(*src, field, j), + dst->GetReflection()->MutableRepeatedMessage(dst, field, j), + seen); + } + } else { + OrderMapEntries(&src->GetReflection()->GetMessage(*src, field), + dst->GetReflection()->MutableMessage(dst, field), seen); + } + } + } +} + +// Copies a Message, keeping map entries in order. +std::unique_ptr CloneMessage(const Message* message) { + std::unique_ptr result(message->New()); + result->CopyFrom(*message); + OrderMapEntries(message, result.get()); + return result; +} + +using MessageMap = std::map>; + // For a non-repeated field, move the most recently parsed field value // into the most recently parsed template expression. -void StowFieldValue(Message* message, TemplateExpression* expression) { +void StowFieldValue(Message* message, TemplateExpression* expression, + MessageMap* stowed_messages) { const Reflection* reflection = message->GetReflection(); const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); - int field_number = path[path.size() - 1].first; + int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); + + // Save each stowed message unserialized preserving map entry order. + if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) { + (*stowed_messages)[ProtoPathJoin(path)] = + CloneMessage(GetFieldMessage(*message, field, 0)); + } + if (!field->is_repeated()) { std::vector field_values; SerializeField(message, field, &field_values); @@ -1402,6 +1512,112 @@ static void StripQuotes(std::string* str) { } } +// Returns the field or extension for field number. +const FieldDescriptor* FindFieldByNumber(const Message* message, + int field_num) { + const FieldDescriptor* result = + message->GetDescriptor()->FindFieldByNumber(field_num); + if (result == nullptr) { + result = message->GetReflection()->FindKnownExtensionByNumber(field_num); + } + return result; +} + +// Returns the protobuf map key types from a ProtoPath. +std::vector ProtoPathKeyTypes(ProtoPath path) { + std::vector result; + for (auto& entry : path) { + if (entry.map_id >= 0) { + result.push_back(entry.key_type); + } + } + return result; +} + +// Returns the text value for a string or numeric protobuf map key. 
+std::string GetMapKey(const Message& map_entry) { + auto key_field = map_entry.GetDescriptor()->FindFieldByName("key"); + auto reflection = map_entry.GetReflection(); + if (key_field->type() == FieldDescriptor::TYPE_STRING) { + return reflection->GetString(map_entry, key_field); + } else if (key_field->type() == FieldDescriptor::TYPE_INT32) { + return absl::StrCat(reflection->GetInt32(map_entry, key_field)); + } else if (key_field->type() == FieldDescriptor::TYPE_INT64) { + return absl::StrCat(reflection->GetInt64(map_entry, key_field)); + } + return ""; +} + +// Returns a Message store in CalculatorGraphTemplate::field_value. +Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) { + auto it = stowed_messages->find(ProtoPathJoin(proto_path)); + return (it != stowed_messages->end()) ? it->second.get() : nullptr; +} + +const Message* GetNestedMessage(const Message& message, + const FieldDescriptor* field, + ProtoPath proto_path, + MessageMap* stowed_messages) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + const Message* result = FindStowedMessage(stowed_messages, proto_path); + if (!result) { + result = GetFieldMessage(message, field, proto_path.back().index); + } + return result; +} + +// Adjusts map-entries from indexes to keys. +// Protobuf map-entry order is intentionally not preserved. +absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) { + // Copy the rules from the source CalculatorGraphTemplate. + mediapipe::CalculatorGraphTemplate rules; + rules.ParsePartialFromString(source->SerializePartialAsString()); + // Only the "source" Message knows all extension types. + Message* config_0 = source->GetReflection()->MutableMessage( + source, source->GetDescriptor()->FindFieldByName("config"), nullptr); + for (int i = 0; i < rules.rule().size(); ++i) { + TemplateExpression* rule = rules.mutable_rule()->Mutable(i); + const Message* message = config_0; + ProtoPath path; + MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); + for (int j = 0; j < path.size(); ++j) { + int field_id = path[j].field_id; + const FieldDescriptor* field = FindFieldByNumber(message, field_id); + ProtoPath prefix = {path.begin(), path.begin() + j + 1}; + message = GetNestedMessage(*message, field, prefix, stowed_messages); + if (!message) { + break; + } + if (field->is_map()) { + const Message* map_entry = message; + int key_id = + map_entry->GetDescriptor()->FindFieldByName("key")->number(); + FieldType key_type = static_cast( + map_entry->GetDescriptor()->FindFieldByName("key")->type()); + std::string key_value = GetMapKey(*map_entry); + path[j] = {field_id, key_id, key_type, key_value}; + } + } + if (!rule->path().empty()) { + *rule->mutable_path() = ProtoPathJoin(path); + for (FieldType key_type : ProtoPathKeyTypes(path)) { + *rule->mutable_key_type()->Add() = key_type; + } + } + } + // Copy the rules back into the source CalculatorGraphTemplate. + auto source_rules = + source->GetReflection()->GetMutableRepeatedFieldRef( + source, source->GetDescriptor()->FindFieldByName("rule")); + source_rules.Clear(); + for (auto& rule : rules.rule()) { + source_rules.Add(rule); + } + return absl::OkStatus(); +} + } // namespace class TemplateParser::Parser::MediaPipeParserImpl @@ -1416,6 +1632,8 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); + // Replace map-entry indexes with map keys. 
+ success &= KeyProtoMapEntries(output, &stowed_messages_).ok(); return success; } @@ -1441,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl DO(ConsumeFieldTemplate(message)); } else { DO(ConsumeField(message)); - StowFieldValue(message, expression); + StowFieldValue(message, expression, &stowed_messages_); } DO(ConsumeEndTemplate()); return true; @@ -1652,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl } mediapipe::CalculatorGraphTemplate template_rules_; + std::map> stowed_messages_; }; #undef DO diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 6433c93d2..c7ed063e0 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -258,11 +258,8 @@ std::string GetTestFilePath(absl::string_view relative_path) { return file::JoinPath(GetTestRootDir(), relative_path); } -absl::StatusOr> LoadTestImage( - absl::string_view path, ImageFormat::Format format) { - std::string encoded; - MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); - +absl::StatusOr> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format) { // stbi_load determines the output pixel format based on the desired channels. // 0 means "use whatever's in the file". int desired_channels = format == ImageFormat::UNKNOWN ? 0 @@ -274,10 +271,10 @@ absl::StatusOr> LoadTestImage( << "unsupported output format requested: " << format; int width, height, channels_in_file; - auto data = stbi_load_from_memory(reinterpret_cast(encoded.data()), - encoded.size(), &width, &height, - &channels_in_file, desired_channels); - RET_CHECK(data) << "failed to decode image data from: " << path; + auto data = stbi_load_from_memory( + reinterpret_cast(encoded.data()), encoded.size(), &width, + &height, &channels_in_file, desired_channels); + RET_CHECK(data) << "failed to decode image data"; // If we didn't specify a desired format, it will be determined by what the // file contains. @@ -295,6 +292,13 @@ absl::StatusOr> LoadTestImage( format, width, height, width * output_channels, data, stbi_image_free); } +absl::StatusOr> LoadTestImage( + absl::string_view path, ImageFormat::Format format) { + std::string encoded; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); + return DecodeTestImage(encoded, format); +} + std::unique_ptr LoadTestPng(absl::string_view path, ImageFormat::Format format) { return nullptr; diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h index 71c096db7..80b768e3d 100644 --- a/mediapipe/framework/tool/test_util.h +++ b/mediapipe/framework/tool/test_util.h @@ -81,6 +81,10 @@ std::string GetTestDataDir(absl::string_view package_base_path); // Loads a binary graph from path. Returns true iff successful. bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path); +// Loads an image from memory. +absl::StatusOr> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format = ImageFormat::SRGBA); + // Loads an image from path. 
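A hedged sketch of the new DecodeTestImage entry point declared above; the function name UseDecodedImage, the encoded_png argument, and the SRGB format choice are assumptions made for illustration.

#include <memory>
#include <string>

#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/test_util.h"

// Decodes image bytes already held in memory, instead of reading a file the
// way LoadTestImage does.
absl::Status UseDecodedImage(const std::string& encoded_png) {
  ASSIGN_OR_RETURN(std::unique_ptr<mediapipe::ImageFrame> frame,
                   mediapipe::DecodeTestImage(encoded_png,
                                              mediapipe::ImageFormat::SRGB));
  LOG(INFO) << "decoded " << frame->Width() << "x" << frame->Height();
  return absl::OkStatus();
}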
absl::StatusOr> LoadTestImage( absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA); diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index 906688520..8300181b5 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -17,10 +17,13 @@ load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_simple_subgraph", ) +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//mediapipe:__subpackages__"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) filegroup( name = "test_graph", @@ -40,7 +43,6 @@ mediapipe_simple_subgraph( testonly = 1, graph = "dub_quad_test_subgraph.pbtxt", register_as = "DubQuadTestSubgraph", - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:test_calculators", ], @@ -51,9 +53,18 @@ mediapipe_simple_subgraph( testonly = 1, graph = "nested_test_subgraph.pbtxt", register_as = "NestedTestSubgraph", - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":dub_quad_test_subgraph", "//mediapipe/framework:test_calculators", ], ) + +mediapipe_proto_library( + name = "frozen_generator_proto", + srcs = ["frozen_generator.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [ + "//mediapipe/framework:packet_generator_proto", + ], +) diff --git a/mediapipe/framework/tool/testdata/frozen_generator.proto b/mediapipe/framework/tool/testdata/frozen_generator.proto new file mode 100644 index 000000000..5f133f461 --- /dev/null +++ b/mediapipe/framework/tool/testdata/frozen_generator.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + +message FrozenGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional FrozenGeneratorOptions ext = 225748738; + } + + // Path to file containing serialized proto of type tensorflow::GraphDef. + optional string graph_proto_path = 1; + + // This map defines which streams are fed to which tensors in the model. + map tag_to_tensor_names = 2; + + // Graph nodes to run to initialize the model. + repeated string initialization_op_names = 4; +} diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 16aad6e9b..15eac3209 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize( input_side_packets_.clear(); output_side_packets_.clear(); stream_to_producer_.clear(); + output_streams_to_consumer_nodes_.clear(); input_streams_.clear(); output_streams_.clear(); owned_packet_types_.clear(); @@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode( << " does not have a corresponding output stream."; } } + // Add this node as a consumer of this edge's output stream.
+ if (edge_info.upstream > -1) { + auto parent_node = output_streams_[edge_info.upstream].parent_node; + if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) { + int this_idx = node_type_info->Node().index; + output_streams_to_consumer_nodes_[edge_info.upstream].push_back( + this_idx); + } + } edge_info.parent_node = node_type_info->Node(); edge_info.name = name; @@ -1048,6 +1058,14 @@ absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( for (const auto& required_item : required_side_packets_) { auto iter = side_packet_types.find(required_item.first); if (iter == side_packet_types.end()) { + bool is_optional = true; + for (int index : required_item.second) { + is_optional &= input_side_packets_[index].packet_type->IsOptional(); + } + if (is_optional) { + // Side packets that are optional and not provided are ignored. + continue; + } statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << required_item.first << "\" is required but was not provided."); diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 11f9553cd..95ecccbb4 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -282,6 +282,14 @@ class ValidatedGraphConfig { return output_streams_[iter->second].parent_node.index; } + std::vector OutputStreamToConsumers(int idx) const { + auto iter = output_streams_to_consumer_nodes_.find(idx); + if (iter == output_streams_to_consumer_nodes_.end()) { + return {}; + } + return iter->second; + } + // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. absl::StatusOr RegisteredSidePacketTypeName( @@ -418,6 +426,10 @@ class ValidatedGraphConfig { // Mapping from stream name to the output_streams_ index which produces it. std::map stream_to_producer_; + + // Mapping from output streams to consumer node ids. Used for profiling. + std::map> output_streams_to_consumer_nodes_; + // Mapping from side packet name to the output_side_packets_ index // which produces it. std::map side_packet_to_producer_; diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9c2f47469..702812718 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -176,6 +176,16 @@ cc_library( "-fobjc-arc", # enable reference-counting ], }), + linkopts = select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "-framework OpenGLES", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework AppKit", + ], + }), visibility = ["//visibility:public"], deps = [ ":attachments", @@ -204,8 +214,10 @@ cc_library( }) + select({ "//conditions:default": [ ], - "//mediapipe:ios": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], }), ) @@ -221,12 +233,18 @@ cc_library( ":gpu_buffer_format", ":gpu_buffer_storage", ":gpu_buffer_storage_image_frame", + "@com_google_absl//absl/memory", # TODO: remove this dependency. Some other teams' tests # depend on having an indirect image_frame dependency, need to be # fixed first. 
"//mediapipe/framework/formats:image_frame", - "@com_google_absl//absl/memory", - ], + ] + select({ + "//conditions:default": [], + ":platform_ios_with_gpu": [ + ":gl_texture_util", + ":gpu_buffer_storage_cv_pixel_buffer", + ], + }), ) cc_library( @@ -271,7 +289,9 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", ":gpu_buffer_storage_image_frame", @@ -344,6 +364,60 @@ cc_library( ], ) +mediapipe_cc_test( + name = "gpu_buffer_storage_cv_pixel_buffer_test", + size = "small", + timeout = "moderate", + srcs = ["gpu_buffer_storage_cv_pixel_buffer_test.cc"], + platforms = ["ios"], + deps = [ + ":gl_texture_buffer", + ":gl_texture_util", + ":gpu_buffer", + ":gpu_buffer_storage_cv_pixel_buffer", + ":gpu_test_base", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:test_util", + "//mediapipe/objc:util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cv_texture_cache_manager", + srcs = ["cv_texture_cache_manager.cc"], + hdrs = ["cv_texture_cache_manager.h"], + deps = [ + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "cv_pixel_buffer_pool_wrapper", + srcs = ["cv_pixel_buffer_pool_wrapper.cc"], + hdrs = ["cv_pixel_buffer_pool_wrapper.h"], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + }), + deps = [ + ":cv_texture_cache_manager", + ":gpu_buffer_format", + ":multi_pool", + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -367,6 +441,21 @@ cc_library( ], ) +cc_library( + name = "gpu_buffer_storage_yuv_image", + srcs = ["gpu_buffer_storage_yuv_image.cc"], + hdrs = ["gpu_buffer_storage_yuv_image.h"], + visibility = ["//visibility:public"], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage", + "//mediapipe/framework/formats:yuv_image", + "//third_party/libyuv", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + ], +) + cc_library( name = "gpu_buffer_storage_ahwb", srcs = ["gpu_buffer_storage_ahwb.cc"], @@ -400,22 +489,19 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Accelerate", - "CoreGraphics", - "CoreVideo", - ], visibility = ["//visibility:public"], - deps = ["//mediapipe/objc:util"], + deps = [ + "//mediapipe/objc:util", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreVideo", + ], ) objc_library( - name = "MPPGraphGPUData", - srcs = [ - "MPPGraphGPUData.mm", - "gpu_shared_data_internal.cc", - ], - hdrs = ["MPPGraphGPUData.h"], + name = "metal_shared_resources", + srcs = ["metal_shared_resources.mm"], + hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", @@ -424,24 +510,9 @@ objc_library( sdk_frameworks = [ "CoreVideo", "Metal", - ] + select({ - "//conditions:default": [ - "OpenGLES", - ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], - }), + ], visibility = ["//visibility:public"], deps = [ - ":gl_base", - 
":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ":graph_support", - "//mediapipe/gpu:gl_context_options_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", "@google_toolbox_for_mac//:GTM_Defines", ] + [ @@ -456,13 +527,11 @@ objc_library( "-x objective-c++", "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@com_google_absl//absl/time", "@google_toolbox_for_mac//:GTM_Defines", ], @@ -489,12 +558,7 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ] + select({ - "//conditions:default": [], - "//mediapipe:apple": [ - "MPPGraphGPUData.h", - ], - }), + ], visibility = ["//visibility:private"], deps = [ ":gl_base", @@ -515,6 +579,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -522,25 +587,27 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) cc_library( name = "gpu_shared_data_internal_actual", - srcs = select({ - "//conditions:default": [ - "gpu_shared_data_internal.cc", - ], - # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. - "//mediapipe:apple": [], - }), + srcs = [ + "gpu_shared_data_internal.cc", + ], hdrs = [ "gpu_shared_data_internal.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -554,7 +621,8 @@ cc_library( ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":MPPGraphGPUData", + ":metal_shared_resources", + ":cv_texture_cache_manager", ], }), ) @@ -569,6 +637,8 @@ cc_library( ":gl_texture_buffer", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", + ":reusable_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -577,6 +647,22 @@ cc_library( ], ) +cc_library( + name = "reusable_pool", + hdrs = ["reusable_pool.h"], + deps = [ + ":multi_pool", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "multi_pool", + hdrs = ["multi_pool.h"], + deps = ["//mediapipe/util:resource_cache"], +) + cc_library( name = "gpu_buffer_multi_pool", srcs = ["gpu_buffer_multi_pool.cc"], @@ -604,6 +690,7 @@ cc_library( ":gl_base", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -617,11 +704,15 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_pixel_buffer_pool_wrapper", + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_pixel_buffer_pool_wrapper", + ":cv_texture_cache_manager", 
":pixel_buffer_pool_util", ":gl_texture_buffer", ":gl_texture_buffer_pool", @@ -629,6 +720,17 @@ cc_library( }), ) +cc_library( + name = "gl_texture_util", + srcs = ["gl_texture_util.cc"], + hdrs = ["gl_texture_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":gl_base", + ":gl_texture_view", + ], +) + cc_library( name = "shader_util", srcs = ["shader_util.cc"], @@ -653,11 +755,9 @@ cc_library( name = "gl_calculator_helper", srcs = [ "gl_calculator_helper.cc", - "gl_calculator_helper_impl_common.cc", ], hdrs = [ "gl_calculator_helper.h", - "gl_calculator_helper_impl.h", ], linkopts = select({ "//conditions:default": [], @@ -689,7 +789,7 @@ cc_library( ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_cc_proto", + "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_contract", @@ -715,20 +815,6 @@ cc_library( }), ) -# TODO: remove -objc_library( - name = "gl_calculator_helper_ios", - copts = [ - "-Wno-shorten-64-to-32", - ], - visibility = ["//visibility:public"], - deps = [ - ":gl_calculator_helper", - "//mediapipe/objc:mediapipe_framework_ios", - "//mediapipe/objc:util", - ], -) - objc_library( name = "MPPMetalHelper", srcs = ["MPPMetalHelper.mm"], @@ -737,15 +823,13 @@ objc_library( "-Wno-shorten-64-to-32", ], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":gpu_shared_data_internal", ":graph_support", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -764,10 +848,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -821,6 +905,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", + ":gpu_buffer_storage_image_frame", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", @@ -831,13 +917,24 @@ cc_library( alwayslink = 1, ) +### Simple calculators + +mediapipe_proto_library( + name = "gl_animation_overlay_calculator_proto", + srcs = ["gl_animation_overlay_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + proto_library( name = "gl_scaler_calculator_proto", srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -859,6 +956,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -866,7 +964,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -879,13 +976,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", 
"//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -895,8 +992,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -947,16 +1044,14 @@ objc_library( name = "metal_copy_calculator", srcs = ["MetalCopyCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", + ":copy_calculator_cc_proto", ":simple_shaders_mtl", - "//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -965,15 +1060,13 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -982,15 +1075,13 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -999,15 +1090,13 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1017,15 +1106,13 @@ objc_library( srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1033,15 +1120,13 @@ objc_library( objc_library( name = "mps_threshold_calculator", srcs = ["MPSThresholdCalculator.mm"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1062,8 +1147,8 @@ objc_library( name = "gl_ios_test_lib", testonly = 1, srcs = [ - "MPPGraphGPUDataTests.mm", "gl_ios_test.mm", + "metal_shared_resources_test.mm", ], copts = [ "-Wno-shorten-64-to-32", @@ -1073,7 +1158,7 @@ objc_library( ], features = ["-layering_check"], deps = 
[ - ":MPPGraphGPUData", + ":metal_shared_resources", ":gl_scaler_calculator", ":gpu_buffer_to_image_frame_calculator", ":gpu_shared_data_internal", @@ -1117,3 +1202,17 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +mediapipe_cc_test( + name = "gpu_buffer_storage_yuv_image_test", + size = "small", + srcs = ["gpu_buffer_storage_yuv_image_test.cc"], + exclude_platforms = [ + "ios", + ], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage_yuv_image", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h deleted file mode 100644 index 3d8fc0c94..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ -#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ - -#import -#import -#import - -#import "mediapipe/gpu/gl_base.h" -#import "mediapipe/gpu/gl_context.h" - -namespace mediapipe { -class GlContext; -class GpuBufferMultiPool; -} // namespace mediapipe - -@interface MPPGraphGPUData : NSObject { - // Shared buffer pool for GPU calculators. - mediapipe::GpuBufferMultiPool* _gpuBufferPool; - mediapipe::GlContext* _glContext; -} - -- (instancetype)init NS_UNAVAILABLE; - -/// Initialize. The provided multipool pointer must remain valid throughout -/// this object's lifetime. -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; - -/// Shared texture pool for GPU calculators. -/// For internal use by GlCalculatorHelper. -@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; - -/// Shared OpenGL context. -#if TARGET_OS_OSX -@property(readonly) NSOpenGLContext* glContext; -@property(readonly) NSOpenGLPixelFormat* glPixelFormat; -#else -@property(readonly) EAGLContext* glContext; -#endif // TARGET_OS_OSX - -/// Shared texture cache. -#if TARGET_OS_OSX -@property(readonly) CVOpenGLTextureCacheRef textureCache; -#else -@property(readonly) CVOpenGLESTextureCacheRef textureCache; -#endif // TARGET_OS_OSX - -/// Shared Metal resources. -@property(readonly) id mtlDevice; -@property(readonly) id mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@property(readonly) CVMetalTextureCacheRef mtlTextureCache; -#endif - -@end - -#endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm deleted file mode 100644 index 8ac1eefa5..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/gpu/MPPGraphGPUData.h" - -#import "GTMDefines.h" - -#include "mediapipe/gpu/gl_context.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX - -@implementation MPPGraphGPUData - -@synthesize textureCache = _textureCache; -@synthesize mtlDevice = _mtlDevice; -@synthesize mtlCommandQueue = _mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@synthesize mtlTextureCache = _mtlTextureCache; -#endif - -#if TARGET_OS_OSX -typedef CVOpenGLTextureCacheRef CVTextureCacheType; -#else -typedef CVOpenGLESTextureCacheRef CVTextureCacheType; -#endif // TARGET_OS_OSX - -- (instancetype)initWithContext:(mediapipe::GlContext *)context - multiPool:(mediapipe::GpuBufferMultiPool *)pool { - self = [super init]; - if (self) { - _gpuBufferPool = pool; - _glContext = context; - } - return self; -} - -- (void)dealloc { - if (_textureCache) { - _textureCache = NULL; - } -#if COREVIDEO_SUPPORTS_METAL - if (_mtlTextureCache) { - CFRelease(_mtlTextureCache); - _mtlTextureCache = NULL; - } -#endif -} - -#if TARGET_OS_OSX -- (NSOpenGLContext *)glContext { - return _glContext->nsgl_context(); -} - -- (NSOpenGLPixelFormat *) glPixelFormat { - return _glContext->nsgl_pixel_format(); -} -#else -- (EAGLContext *)glContext { - return _glContext->eagl_context(); -} -#endif // TARGET_OS_OSX - -- (CVTextureCacheType)textureCache { - @synchronized(self) { - if (!_textureCache) { - _textureCache = _glContext->cv_texture_cache(); - } - } - return _textureCache; -} - -- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { - return _gpuBufferPool; -} - -- (id)mtlDevice { - @synchronized(self) { - if (!_mtlDevice) { - _mtlDevice = MTLCreateSystemDefaultDevice(); - } - } - return _mtlDevice; -} - -- (id)mtlCommandQueue { - @synchronized(self) { - if (!_mtlCommandQueue) { - _mtlCommandQueue = [self.mtlDevice newCommandQueue]; - } - } - return _mtlCommandQueue; -} - -#if COREVIDEO_SUPPORTS_METAL -- (CVMetalTextureCacheRef)mtlTextureCache { - @synchronized(self) { - if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); - // TODO: register and flush metal caches too. - } - } - return _mtlTextureCache; -} -#endif - -@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm deleted file mode 100644 index e8b50845b..000000000 --- a/mediapipe/gpu/MPPGraphGPUDataTests.mm +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/port/threadpool.h" - -#import "mediapipe/gpu/MPPGraphGPUData.h" -#import "mediapipe/gpu/gpu_shared_data_internal.h" - -@interface MPPGraphGPUDataTests : XCTestCase { -} -@end - -@implementation MPPGraphGPUDataTests - -// This test verifies that the internal Objective-C object is correctly -// released when the C++ wrapper is released. -- (void)testCorrectlyReleased { - __weak id gpuData = nil; - std::weak_ptr gpuRes; - @autoreleasepool { - mediapipe::GpuSharedData gpu_shared; - gpuRes = gpu_shared.gpu_resources; - gpuData = gpu_shared.gpu_resources->ios_gpu_data(); - XCTAssertNotEqual(gpuRes.lock(), nullptr); - XCTAssertNotNil(gpuData); - } - XCTAssertEqual(gpuRes.lock(), nullptr); - XCTAssertNil(gpuData); -} - -// This test verifies that the lazy initialization of the glContext instance -// variable is thread-safe. All threads should read the same value. -- (void)testGlContextThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - EAGLContext* ogl_context[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &ogl_context, i] { - ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); - } -} - -// This test verifies that the lazy initialization of the textureCache instance -// variable is thread-safe. All threads should read the same value. -- (void)testTextureCacheThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - CFHolder texture_cache[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &texture_cache, i] { - texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); - } -} - -@end diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index f3662422e..6ae0f3cf9 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -21,37 +21,35 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" NS_ASSUME_NONNULL_BEGIN @interface MPPMetalHelper : NSObject { - MPPGraphGPUData* _gpuShared; } - (instancetype)init NS_UNAVAILABLE; /// Initialize. This initializer is recommended for calculators. -- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext *)cc; /// Initialize. -- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources +- (instancetype)initWithGpuResources:(mediapipe::GpuResources *)gpuResources NS_DESIGNATED_INITIALIZER; /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract *)cc; /// Deprecated initializer. 
-- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet &)inputSidePackets; /// Deprecated initializer. -- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData *)gpuShared; /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations. -+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet *)inputSidePackets; /// Get a metal command buffer. /// Calculators should use this method instead of getting a buffer from the @@ -63,23 +61,23 @@ NS_ASSUME_NONNULL_BEGIN /// Creates a CVMetalTextureRef linked to the provided GpuBuffer. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Returns a MTLTexture linked to the provided GpuBuffer. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Obtains a new GpuBuffer to be used as an output destination. @@ -91,7 +89,7 @@ NS_ASSUME_NONNULL_BEGIN format:(mediapipe::GpuBufferFormat)format; /// Convenience method to load a Metal library stored as a bundle resource. -- (id)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; +- (id)newLibraryWithResourceName:(NSString *)name error:(NSError *_Nullable *)error; /// Shared Metal resources. @property(readonly) id mtlDevice; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index ce6620972..1acf7cbfb 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,11 +14,18 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/graph_support.h" +#import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" +@interface MPPMetalHelper () { + mediapipe::GpuResources* _gpuResources; +} +@end + namespace mediapipe { // Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. 
@@ -40,7 +47,7 @@ class MetalHelperLegacySupport { - (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { self = [super init]; if (self) { - _gpuShared = gpuResources->ios_gpu_data(); + _gpuResources = gpuResources; } return self; } @@ -105,19 +112,19 @@ class MetalHelperLegacySupport { } - (id)mtlDevice { - return _gpuShared.mtlDevice; + return _gpuResources->metal_shared().resources().mtlDevice; } - (id)mtlCommandQueue { - return _gpuShared.mtlCommandQueue; + return _gpuResources->metal_shared().resources().mtlCommandQueue; } - (CVMetalTextureCacheRef)mtlTextureCache { - return _gpuShared.mtlTextureCache; + return _gpuResources->metal_shared().resources().mtlTextureCache; } - (id)commandBuffer { - return [_gpuShared.mtlCommandQueue commandBuffer]; + return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer @@ -169,8 +176,9 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, - metalPixelFormat, width, height, plane, &texture); + NULL, _gpuResources->metal_shared().resources().mtlTextureCache, + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, + &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; } @@ -191,19 +199,20 @@ class MetalHelperLegacySupport { } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - return _gpuShared.gpuBufferPool->GetBuffer(width, height); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height format:(mediapipe::GpuBufferFormat)format { - return _gpuShared.gpuBufferPool->GetBuffer(width, height, format); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } - (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] - pathForResource:name ofType:@"metallib"] - error:error]; + return [_gpuResources->metal_shared().resources().mtlDevice + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name + ofType:@"metallib"] + error:error]; } @end diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h index ca9f074c4..3a73e4676 100644 --- a/mediapipe/gpu/attachments.h +++ b/mediapipe/gpu/attachments.h @@ -31,8 +31,8 @@ class AttachmentBase {}; template class Attachment : public AttachmentBase { public: - using FactoryT = std::function(Context&)>; - Attachment(FactoryT factory) : factory_(factory) {} + using FactoryT = AttachmentPtr (*)(Context&); + explicit constexpr Attachment(FactoryT factory) : factory_(factory) {} Attachment(const Attachment&) = delete; Attachment(Attachment&&) = delete; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc new file mode 100644 index 000000000..6e077ae6e --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -0,0 +1,84 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" + +#include + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" + +namespace mediapipe { + +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( + int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, + CvTextureCacheManager* texture_caches) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + pool_ = MakeCFHolderAdopting( + /* keep count is 0 because the age param keeps buffers around anyway */ + CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); + texture_caches_ = texture_caches; +} + +CFHolder CvPixelBufferPoolWrapper::GetBuffer() { + CVPixelBufferRef buffer; + int threshold = 1; + NSMutableDictionary* auxAttributes = + [NSMutableDictionary dictionaryWithCapacity:1]; + CVReturn err; + bool tried_flushing = false; + while (1) { + auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, + &buffer); + if (err != kCVReturnWouldExceedAllocationThreshold) break; + if (texture_caches_ && !tried_flushing) { + // Call the flush function to potentially release old holds on buffers + // and try again to create a pixel buffer. + // This is used to flush CV texture caches, which may retain buffers until + // flushed. + texture_caches_->FlushTextureCaches(); + tried_flushing = true; + } else { + ++threshold; + } + } + CHECK(!err) << "Error creating pixel buffer: " << err; + count_ = threshold; + return MakeCFHolderAdopting(buffer); +} + +std::string CvPixelBufferPoolWrapper::GetDebugString() const { + auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); + return [(__bridge NSString*)*description UTF8String]; +} + +void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } + +CFHolder CvPixelBufferPoolWrapper::CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, + cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return MakeCFHolderAdopting(buffer); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h new file mode 100644 index 000000000..4d71adbf2 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -0,0 +1,66 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. + +#ifndef MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ +#define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/gpu/cv_texture_cache_manager.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/multi_pool.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvPixelBufferPoolWrapper { + public: + CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, + CFTimeInterval maxAge, + CvTextureCacheManager* texture_caches); + + static std::shared_ptr Create( + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options, + CvTextureCacheManager* texture_caches = nullptr) { + return std::make_shared( + spec.width, spec.height, spec.format, options.max_inactive_buffer_age, + texture_caches); + } + + CFHolder GetBuffer(); + + int GetBufferCount() const { return count_; } + std::string GetDebugString() const; + + void Flush(); + + static CFHolder CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec); + + private: + CFHolder pool_; + int count_ = 0; + CvTextureCacheManager* texture_caches_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ diff --git a/mediapipe/gpu/cv_texture_cache_manager.cc b/mediapipe/gpu/cv_texture_cache_manager.cc new file mode 100644 index 000000000..b977a8993 --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/gpu/cv_texture_cache_manager.h" + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +void CvTextureCacheManager::FlushTextureCaches() { + absl::MutexLock lock(&mutex_); + for (const auto& cache : texture_caches_) { +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(*cache, 0); +#else + CVOpenGLESTextureCacheFlush(*cache, 0); +#endif // TARGET_OS_OSX + } +} + +void CvTextureCacheManager::RegisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) + << "Attempting to register a texture cache twice"; + texture_caches_.emplace_back(cache); +} + +void CvTextureCacheManager::UnregisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); + CHECK(it != texture_caches_.end()) + << "Attempting to unregister an unknown texture cache"; + texture_caches_.erase(it); +} + +CvTextureCacheManager::~CvTextureCacheManager() { + CHECK_EQ(texture_caches_.size(), 0) + << "Failed to unregister texture caches before deleting manager"; +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_texture_cache_manager.h b/mediapipe/gpu/cv_texture_cache_manager.h new file mode 100644 index 000000000..17e44fc6e --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.h @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ +#define MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvTextureCacheManager { + public: + ~CvTextureCacheManager(); + + // TODO: add tests for the texture cache registration. + + // Inform the pool of a cache that should be flushed when it is low on + // reusable buffers. + void RegisterTextureCache(CVTextureCacheType cache); + + // Remove a texture cache from the list of caches to be flushed. 
+ void UnregisterTextureCache(CVTextureCacheType cache); + + void FlushTextureCaches(); + + private: + absl::Mutex mutex_; + std::vector> texture_caches_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index ba1423977..9b217ddfd 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -20,38 +20,37 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_service.h" namespace mediapipe { -// The constructor and destructor need to be defined here so that -// std::unique_ptr can see the full definition of GlCalculatorHelperImpl. -// In the header, it is an incomplete type. GlCalculatorHelper::GlCalculatorHelper() {} GlCalculatorHelper::~GlCalculatorHelper() {} +void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, + GpuResources* gpu_resources) { + gpu_resources_ = gpu_resources; + gl_context_ = gpu_resources_->gl_context(cc); +} + absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); auto gpu_service = cc->Service(kGpuService); RET_CHECK(gpu_service.IsAvailable()) << "GPU service not available. Did you forget to call " "GlCalculatorHelper::UpdateContract?"; - // TODO return error from impl_ (needs two-stage init) - impl_ = - absl::make_unique(cc, &gpu_service.GetObject()); + InitializeInternal(cc, &gpu_service.GetObject()); return absl::OkStatus(); } void GlCalculatorHelper::InitializeForTest(GpuSharedData* gpu_shared) { - impl_ = absl::make_unique( - nullptr, gpu_shared->gpu_resources.get()); + InitializeInternal(nullptr, gpu_shared->gpu_resources.get()); } void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { - impl_ = absl::make_unique(nullptr, gpu_resources); + InitializeInternal(nullptr, gpu_resources); } // static @@ -88,44 +87,109 @@ absl::Status GlCalculatorHelper::SetupInputSidePackets( return absl::OkStatus(); } +absl::Status GlCalculatorHelper::RunInGlContext( + std::function gl_func, + CalculatorContext* calculator_context) { + if (calculator_context) { + return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), + calculator_context->InputTimestamp()); + } else { + return gl_context_->Run(std::move(gl_func)); + } +} + absl::Status GlCalculatorHelper::RunInGlContext( std::function gl_func) { - if (!impl_) return absl::InternalError("helper not initialized"); + if (!Initialized()) return absl::InternalError("helper not initialized"); // TODO: Remove LegacyCalculatorSupport from MediaPipe OSS. auto calculator_context = LegacyCalculatorSupport::Scoped::current(); - return impl_->RunInGlContext(gl_func, calculator_context); + return RunInGlContext(gl_func, calculator_context); } -GLuint GlCalculatorHelper::framebuffer() const { return impl_->framebuffer(); } +GLuint GlCalculatorHelper::framebuffer() const { return framebuffer_; } + +void GlCalculatorHelper::CreateFramebuffer() { + // Our framebuffer will have a color attachment but no depth attachment, + // so it's important that the depth test be off. It is disabled by default, + // but we wanted to be explicit. + // TODO: move this to glBindFramebuffer? Or just remove. 
+ glDisable(GL_DEPTH_TEST); + framebuffer_ = kUtilityFramebuffer.Get(*gl_context_); +} void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { - return impl_->BindFramebuffer(dst); +#ifdef __ANDROID__ + // On (some?) Android devices, attaching a new texture to the frame buffer + // does not seem to detach the old one. As a result, using that texture + // for texturing can produce incorrect output. See b/32091368 for details. + // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 + // or glFramebufferTexture2D with a texture ID of 0. + glBindFramebuffer(GL_FRAMEBUFFER, 0); +#endif + if (!framebuffer_) { + CreateFramebuffer(); + } + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, dst.width(), dst.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), + dst.name(), 0); + +#ifndef NDEBUG + GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + if (status != GL_FRAMEBUFFER_COMPLETE) { + VLOG(2) << "incomplete framebuffer: " << status; + } +#endif } -GlTexture GlCalculatorHelper::CreateSourceTexture( - const GpuBuffer& pixel_buffer) { - return impl_->CreateSourceTexture(pixel_buffer); +GlTexture GlCalculatorHelper::MapGpuBuffer(const GpuBuffer& gpu_buffer, + GlTextureView view) { + if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? + glBindTexture(view.target(), view.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer.format(), view.plane(), GetGlVersion()); + gl_context_->SetStandardTextureParams(view.target(), + info.gl_internal_format); + glBindTexture(view.target(), 0); + } + + return GlTexture(std::move(view), gpu_buffer); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer) { + return CreateSourceTexture(gpu_buffer, 0); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer, + int plane) { + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const ImageFrame& image_frame) { - return impl_->CreateSourceTexture(image_frame); -} - -GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, - int plane) { - return impl_->CreateSourceTexture(pixel_buffer, plane); + auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); } GpuBuffer GlCalculatorHelper::GpuBufferWithImageFrame( std::shared_ptr image_frame) { - return impl_->GpuBufferWithImageFrame(std::move(image_frame)); + return GpuBuffer( + std::make_shared(std::move(image_frame))); } GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( const ImageFrame& image_frame) { - return impl_->GpuBufferCopyingImageFrame(image_frame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS. 
+ CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +#else + return GpuBuffer(GlTextureBuffer::Create(image_frame)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, @@ -136,23 +200,36 @@ void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, *height = pixel_buffer.height(); } -GlTexture GlCalculatorHelper::CreateDestinationTexture(int output_width, - int output_height, +GlTexture GlCalculatorHelper::CreateDestinationTexture(int width, int height, GpuBufferFormat format) { - return impl_->CreateDestinationTexture(output_width, output_height, format); -} + if (!framebuffer_) { + CreateFramebuffer(); + } -GlContext& GlCalculatorHelper::GetGlContext() const { - return impl_->GetGlContext(); -} - -GlVersion GlCalculatorHelper::GetGlVersion() const { - return impl_->GetGlVersion(); + GpuBuffer gpu_buffer = + gpu_resources_->gpu_buffer_pool().GetBuffer(width, height, format); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const mediapipe::Image& image) { - return impl_->CreateSourceTexture(image.GetGpuBuffer()); + return CreateSourceTexture(image.GetGpuBuffer()); +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + std::shared_ptr view = + gpu_buffer_.GetReadView(); + auto copy = absl::make_unique(); + copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); + return copy; +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + return absl::make_unique(gpu_buffer_); } template <> diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index e44523202..af897bbe9 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -17,6 +17,7 @@ #include +#include "absl/base/attributes.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" @@ -33,7 +34,6 @@ namespace mediapipe { -class GlCalculatorHelperImpl; class GlTexture; class GpuResources; struct GpuSharedData; @@ -62,6 +62,7 @@ class GlCalculatorHelper { // Can be used to initialize the helper outside of a calculator. Useful for // testing. void InitializeForTest(GpuResources* gpu_resources); + ABSL_DEPRECATED("Use InitializeForTest(GpuResources)") void InitializeForTest(GpuSharedData* gpu_shared); // This method can be called from GetContract to set up the needed GPU @@ -70,6 +71,7 @@ class GlCalculatorHelper { // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). + ABSL_DEPRECATED("Use UpdateContract") static absl::Status SetupInputSidePackets(PacketTypeSet* input_side_packets); // Execute the provided function within the helper's GL context. On some @@ -161,15 +163,30 @@ class GlCalculatorHelper { // TODO: do we need an unbind method too? void BindFramebuffer(const GlTexture& dst); - GlContext& GetGlContext() const; + GlContext& GetGlContext() const { return *gl_context_; } - GlVersion GetGlVersion() const; + GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } // Check if the calculator helper has been previously initialized. 
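A sketch (not part of the patch) of how a calculator keeps using GlCalculatorHelper after the refactor above: the separate implementation class is gone, but the public pattern is unchanged. The calculator name and the pass-through rendering step are assumptions for illustration:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/gpu/gl_calculator_helper.h"

namespace mediapipe {

class ExampleGpuCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<GpuBuffer>();
    cc->Outputs().Index(0).Set<GpuBuffer>();
    return GlCalculatorHelper::UpdateContract(cc);
  }

  absl::Status Open(CalculatorContext* cc) override { return helper_.Open(cc); }

  absl::Status Process(CalculatorContext* cc) override {
    return helper_.RunInGlContext([this, cc]() -> absl::Status {
      GlTexture src =
          helper_.CreateSourceTexture(cc->Inputs().Index(0).Get<GpuBuffer>());
      GlTexture dst =
          helper_.CreateDestinationTexture(src.width(), src.height());
      helper_.BindFramebuffer(dst);  // binds the shared utility framebuffer
      // ... issue GL draw calls reading from src here ...
      auto output = dst.GetFrame<GpuBuffer>();
      src.Release();
      dst.Release();
      cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
      return absl::OkStatus();
    });
  }

 private:
  GlCalculatorHelper helper_;
};
REGISTER_CALCULATOR(ExampleGpuCalculator);

}  // namespace mediapipe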
- bool Initialized() { return impl_ != nullptr; } + bool Initialized() { return gpu_resources_ != nullptr; } private: - std::unique_ptr impl_; + void InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources); + + absl::Status RunInGlContext(std::function gl_func, + CalculatorContext* calculator_context); + + // Makes a GpuBuffer accessible as a texture in the GL context. + GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); + + // Create the framebuffer for rendering. + void CreateFramebuffer(); + + std::shared_ptr gl_context_; + + GLuint framebuffer_ = 0; + + GpuResources* gpu_resources_ = nullptr; }; // Represents an OpenGL texture, and is a 'view' into the memory pool. @@ -201,9 +218,13 @@ class GlTexture { void Release() { view_ = std::make_shared(); } private: - explicit GlTexture(GlTextureView view) - : view_(std::make_shared(std::move(view))) {} - friend class GlCalculatorHelperImpl; + explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) + : gpu_buffer_(std::move(gpu_buffer)), + view_(std::make_shared(std::move(view))) {} + friend class GlCalculatorHelper; + // We store the GpuBuffer to support GetFrame, and to ensure that the storage + // outlives the view. + GpuBuffer gpu_buffer_; std::shared_ptr view_; }; @@ -217,12 +238,14 @@ class GlTexture { // it is better to keep const-safety and accept having two versions of the // same thing. template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(const T& collection, const std::string& tag, int index) -> decltype(collection.Tag(tag)) { return collection.UsesTags() ? collection.Tag(tag) : collection.Index(index); } template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(T* collection, const std::string& tag, int index) -> decltype(collection->Tag(tag)) { return collection->UsesTags() ? collection->Tag(tag) @@ -230,12 +253,14 @@ auto TagOrIndex(T* collection, const std::string& tag, int index) } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(const T& collection, const std::string& tag, int index) { return collection.UsesTags() ? collection.HasTag(tag) : index < collection.NumEntries(); } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(T* collection, const std::string& tag, int index) { return collection->UsesTags() ? collection->HasTag(tag) : index < collection->NumEntries(); diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h deleted file mode 100644 index 72b3265fe..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ -#define MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ - -#include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" - -#ifdef __OBJC__ -#import -#import -#endif // __OBJC__ - -#ifdef __ANDROID__ -#include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif - -namespace mediapipe { - -// This class implements the GlCalculatorHelper for iOS and Android. -// See GlCalculatorHelper for details on these methods. -class GlCalculatorHelperImpl { - public: - explicit GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources); - ~GlCalculatorHelperImpl(); - - absl::Status RunInGlContext(std::function gl_func, - CalculatorContext* calculator_context); - - GlTexture CreateSourceTexture(const ImageFrame& image_frame); - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer); - - // Note: multi-plane support is currently only available on iOS. - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer, int plane); - - // Creates a framebuffer and returns the texture that it is bound to. - GlTexture CreateDestinationTexture(int output_width, int output_height, - GpuBufferFormat format); - - GpuBuffer GpuBufferWithImageFrame(std::shared_ptr image_frame); - GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); - - GLuint framebuffer() const { return framebuffer_; } - void BindFramebuffer(const GlTexture& dst); - - GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } - - GlContext& GetGlContext() const; - - // For internal use. - static void ReadTexture(const GlTextureView& view, void* output, size_t size); - - private: - // Makes a GpuBuffer accessible as a texture in the GL context. - GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); - - // Create the framebuffer for rendering. - void CreateFramebuffer(); - - std::shared_ptr gl_context_; - - GLuint framebuffer_ = 0; - - GpuResources& gpu_resources_; -}; - -} // namespace mediapipe - -#endif // MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc deleted file mode 100644 index c5c028d4f..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" -#include "mediapipe/gpu/gpu_buffer_format.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#include "mediapipe/gpu/image_frame_view.h" - -namespace mediapipe { - -GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources) - : gpu_resources_(*gpu_resources) { - gl_context_ = gpu_resources_.gl_context(cc); -} - -GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} - -GlContext& GlCalculatorHelperImpl::GetGlContext() const { return *gl_context_; } - -absl::Status GlCalculatorHelperImpl::RunInGlContext( - std::function gl_func, - CalculatorContext* calculator_context) { - if (calculator_context) { - return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), - calculator_context->InputTimestamp()); - } else { - return gl_context_->Run(std::move(gl_func)); - } -} - -void GlCalculatorHelperImpl::CreateFramebuffer() { - // Our framebuffer will have a color attachment but no depth attachment, - // so it's important that the depth test be off. It is disabled by default, - // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? - glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); -} - -void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { -#ifdef __ANDROID__ - // On (some?) Android devices, attaching a new texture to the frame buffer - // does not seem to detach the old one. As a result, using that texture - // for texturing can produce incorrect output. See b/32091368 for details. - // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 - // or glFramebufferTexture2D with a texture ID of 0. - glBindFramebuffer(GL_FRAMEBUFFER, 0); -#endif - if (!framebuffer_) { - CreateFramebuffer(); - } - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - glViewport(0, 0, dst.width(), dst.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), - dst.name(), 0); - -#ifndef NDEBUG - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); - if (status != GL_FRAMEBUFFER_COMPLETE) { - VLOG(2) << "incomplete framebuffer: " << status; - } -#endif -} - -GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, - GlTextureView view) { - if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { - // TODO: do the params need to be reset here?? 
- glBindTexture(view.target(), view.name()); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - gpu_buffer.format(), view.plane(), GetGlVersion()); - gl_context_->SetStandardTextureParams(view.target(), - info.gl_internal_format); - glBindTexture(view.target(), 0); - } - - return GlTexture(std::move(view)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer) { - return CreateSourceTexture(gpu_buffer, 0); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer, int plane) { - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const ImageFrame& image_frame) { - auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferWithImageFrame( - std::shared_ptr image_frame) { - return GpuBuffer( - std::make_shared(std::move(image_frame))); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferCopyingImageFrame( - const ImageFrame& image_frame) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. - CHECK_OK(maybe_buffer.status()); - return GpuBuffer(std::move(maybe_buffer).value()); -#else - return GpuBuffer(GlTextureBuffer::Create(image_frame)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - view_->DoneWriting(); - std::shared_ptr view = - view_->gpu_buffer().GetReadView(); - auto copy = absl::make_unique(); - copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); - return copy; -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - auto gpu_buffer = view_->gpu_buffer(); -#ifdef __EMSCRIPTEN__ - // When WebGL is used, the GL context may be spontaneously lost which can - // cause GpuBuffer allocations to fail. In that case, return a dummy buffer - // to allow processing of the current frame complete. - if (!gpu_buffer) { - return std::make_unique(); - } -#endif // __EMSCRIPTEN__ - view_->DoneWriting(); - return absl::make_unique(gpu_buffer); -} - -GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( - int width, int height, GpuBufferFormat format) { - if (!framebuffer_) { - CreateFramebuffer(); - } - - GpuBuffer gpu_buffer = - gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); -} - -} // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 7f7ba0e23..99b995dda 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -290,8 +290,15 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // some Emscripten cases), there might be some existing tripped error. ForceClearExistingGlErrors(); - absl::string_view version_string( - reinterpret_cast(glGetString(GL_VERSION))); + absl::string_view version_string; + const GLubyte* version_string_ptr = glGetString(GL_VERSION); + if (version_string_ptr != nullptr) { + version_string = reinterpret_cast(version_string_ptr); + } else { + // This may happen when using SwiftShader, but the numeric versions are + // available and will be used instead. 
+ LOG(WARNING) << "failed to get GL_VERSION string"; + } // We will decide later whether we want to use the version numbers we query // for, or instead derive that information from the context creation result, @@ -333,7 +340,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_ - << " (" << glGetString(GL_VERSION) << ")"; + << " (" << version_string << ")"; { auto status = GetGlExtensions(); if (!status.ok()) { @@ -826,10 +833,14 @@ std::shared_ptr GlContext::CreateSyncToken() { return token; } -bool GlContext::IsAnyContextCurrent() { +PlatformGlContext GlContext::GetCurrentNativeContext() { ContextBinding ctx; GetCurrentContextBinding(&ctx); - return ctx.context != kPlatformGlContextNone; + return ctx.context; +} + +bool GlContext::IsAnyContextCurrent() { + return GetCurrentNativeContext() != kPlatformGlContextNone; } std::shared_ptr @@ -1043,4 +1054,16 @@ void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) { glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); } +const GlContext::Attachment kUtilityFramebuffer( + [](GlContext&) -> GlContext::Attachment::Ptr { + GLuint framebuffer; + glGenFramebuffers(1, &framebuffer); + if (!framebuffer) return nullptr; + return {new GLuint(framebuffer), [](void* ptr) { + GLuint* fb = static_cast(ptr); + glDeleteFramebuffers(1, fb); + delete fb; + }}; + }); + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 957cb510f..4f2390404 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -307,6 +307,10 @@ class GlContext : public std::enable_shared_from_this { // the GlContext class, is current. static bool IsAnyContextCurrent(); + // Returns the current native context, whether managed by this class or not. + // Useful as a cross-platform way to get the current PlatformGlContext. + static PlatformGlContext GetCurrentNativeContext(); + // Creates a synchronization token for the current, non-GlContext-owned // context. This can be passed to MediaPipe so it can synchronize with the // commands issued in the external context up to this point. @@ -470,6 +474,12 @@ class GlContext : public std::enable_shared_from_this { bool destructing_ = false; }; +// A framebuffer that the framework can use to attach textures for rendering +// etc. +// This could just be a member of GlContext, but it serves as a basic example +// of an attachment. +ABSL_CONST_INIT extern const GlContext::Attachment kUtilityFramebuffer; + // For backward compatibility. TODO: migrate remaining callers. ABSL_DEPRECATED( "Prefer passing an explicit GlVersion argument (use " diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index 31500ed9a..ad867c2be 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -37,7 +37,6 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; // VIDEO or index 0: GpuBuffers to be rendered. // Side inputs: // SURFACE: unique_ptr to an EglSurfaceHolder to draw to. -// GPU_SHARED: shared GPU resources. // // See GlSurfaceSinkCalculatorOptions for options. 
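The gl_context changes above add two small cross-cutting utilities: GlContext::GetCurrentNativeContext() exposes whatever native context is current on the calling thread, and kUtilityFramebuffer is a per-context Attachment holding a lazily created FBO that framework code can bind instead of allocating a temporary one (ReadTexture later in this patch does exactly that). The following is a minimal sketch of how code running on a GlContext might use them; the function name ReadBackRgba, the RGBA8 read format, and the error handling are illustrative assumptions, not part of the change.

#include <cstdint>
#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gl_texture_view.h"

namespace mediapipe {

// Sketch: read a texture back through the shared utility framebuffer.
absl::Status ReadBackRgba(std::shared_ptr<GlContext> gl_context,
                          const GlTextureView& view,
                          std::vector<uint8_t>* pixels) {
  return gl_context->Run([&]() -> absl::Status {
    // Sanity check that some context is current on this thread.
    if (GlContext::GetCurrentNativeContext() == kPlatformGlContextNone) {
      return absl::InternalError("no GL context is current");
    }
    // The Attachment caches one framebuffer per GlContext.
    GLuint fbo = kUtilityFramebuffer.Get(*gl_context);
    glBindFramebuffer(GL_FRAMEBUFFER, fbo);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
                           view.name(), 0);
    pixels->resize(view.width() * view.height() * 4);  // assumes RGBA8 data
    glReadPixels(0, 0, view.width(), view.height(), GL_RGBA, GL_UNSIGNED_BYTE,
                 pixels->data());
    // Detach the texture and restore the default framebuffer binding.
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
                           0, 0);
    glBindFramebuffer(GL_FRAMEBUFFER, 0);
    return absl::OkStatus();
  });
}

}  // namespace mediapipe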
class GlSurfaceSinkCalculator : public Node { diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index fbb91a8f5..69b9889c7 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -15,9 +15,15 @@ #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/gl_texture_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + namespace mediapipe { std::unique_ptr GlTextureBuffer::Wrap( @@ -250,39 +256,46 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { // precisely, on only one GL context. } -GlTextureView GlTextureBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { +GlTextureView GlTextureBuffer::GetReadView(internal::types, + int plane) const { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](GlTextureView& texture) { - // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. - DidRead(texture.gl_context()->CreateSyncToken()); - }; + GlTextureView::DetachFn detach = + [texbuf = shared_from_this()](GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its + // contents, and create a consumer sync point. + texbuf->DidRead(texture.gl_context()->CreateSyncToken()); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, std::move(detach), - nullptr); + plane, std::move(detach), nullptr); } -GlTextureView GlTextureBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { +GlTextureView GlTextureBuffer::GetWriteView(internal::types, + int plane) { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers. 
GlTextureView::DoneWritingFn done_writing = - [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; + [texbuf = shared_from_this()](const GlTextureView& texture) { + texbuf->ViewDoneWriting(texture); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, nullptr, - std::move(done_writing)); + plane, nullptr, std::move(done_writing)); } void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { @@ -321,8 +334,8 @@ void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { #endif // __ANDROID__ } -static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, - void* output, size_t size) { +static void ReadTexture(GlContext& ctx, const GlTextureView& view, + GpuBufferFormat format, void* output, size_t size) { // TODO: check buffer size? We could use glReadnPixels where available // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read // won't overflow the buffer with glReadPixels, we'd also need to check or @@ -332,13 +345,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); - GLint previous_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - - // We use a temp fbo to avoid depending on the app having an existing one. - // TODO: keep a utility fbo around in the context? - GLuint fbo = 0; - glGenFramebuffers(1, &fbo); + GLuint fbo = kUtilityFramebuffer.Get(ctx); glBindFramebuffer(GL_FRAMEBUFFER, fbo); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), view.name(), 0); @@ -346,9 +353,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, output); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, 0); - // TODO: just set the binding to 0 to avoid the get call? 
- glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); - glDeleteFramebuffers(1, &fbo); + glBindFramebuffer(GL_FRAMEBUFFER, 0); } static std::shared_ptr ConvertToImageFrame( @@ -358,9 +363,11 @@ static std::shared_ptr ConvertToImageFrame( auto output = absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); - buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, nullptr, 0); - ReadTexture(view, buf->format(), output->MutablePixelData(), + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); + ctx->Run([buf, &output, &ctx] { + auto view = buf->GetReadView(internal::types{}, /*plane=*/0); + ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); return std::make_shared(std::move(output)); @@ -380,4 +387,30 @@ static auto kConverterRegistration2 = .RegisterConverter( ConvertFromImageFrame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +static std::shared_ptr ConvertToCvPixelBuffer( + std::shared_ptr buf) { + auto output = absl::make_unique( + buf->width(), buf->height(), buf->format()); + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); + ctx->Run([buf, &output] { + TempGlFramebuffer framebuffer; + auto src = buf->GetReadView(internal::types{}, /*plane=*/0); + auto dst = + output->GetWriteView(internal::types{}, /*plane=*/0); + CopyGlTexture(src, dst); + glFlush(); + }); + return output; +} + +static auto kConverterRegistrationCvpb = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertToCvPixelBuffer); + +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 124a0ec2f..f785571a1 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -35,7 +35,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer : public internal::GpuBufferStorageImpl< - GlTextureBuffer, internal::ViewProvider> { + GlTextureBuffer, internal::ViewProvider>, + public std::enable_shared_from_this { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext. If the GlTextureBuffer has @@ -71,6 +72,11 @@ class GlTextureBuffer // Create a texture with a copy of the data in image_frame. static std::unique_ptr Create(const ImageFrame& image_frame); + static std::unique_ptr Create( + const internal::GpuBufferSpec& spec) { + return Create(spec.width, spec.height, spec.format); + } + // Wraps an existing texture, but does not take ownership of it. // deletion_callback is invoked when the GlTextureBuffer is released, so // the caller knows that the texture is no longer in use. 
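The CVPixelBuffer converter above drives its copy with the new gl_texture_util helpers; the same helpers are usable on their own, for example in tests. Below is a minimal sketch assuming both views were obtained on the GL context current on this thread and have identical dimensions; DuplicateTexture and the optional clear are illustrative, not part of the change.

#include "mediapipe/gpu/gl_texture_util.h"
#include "mediapipe/gpu/gl_texture_view.h"

namespace mediapipe {

// Sketch: copy src into dst via a scratch framebuffer.
void DuplicateTexture(const GlTextureView& src, GlTextureView& dst) {
  TempGlFramebuffer framebuffer;               // binds a temporary FBO
  FillGlTextureRgba(dst, 0.f, 0.f, 0.f, 0.f);  // optional: clear dst first
  CopyGlTexture(src, dst);                     // attaches src, copies into dst
  glFlush();
}

}  // namespace mediapipe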
@@ -90,10 +96,8 @@ class GlTextureBuffer GpuBufferFormat format() const { return format_; } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; // If this texture is going to be used outside of the context that produced @@ -138,6 +142,10 @@ class GlTextureBuffer return producer_context_; } +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + static constexpr bool kDisableGpuBufferRegistration = true; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + private: // Creates a texture of dimensions width x height and allocates space for it. // If data is provided, it is uploaded to the texture; otherwise, it can be diff --git a/mediapipe/gpu/gl_texture_buffer_pool.cc b/mediapipe/gpu/gl_texture_buffer_pool.cc index 3d5a8cdaa..599381a34 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.cc +++ b/mediapipe/gpu/gl_texture_buffer_pool.cc @@ -16,79 +16,4 @@ #include "absl/synchronization/mutex.h" -namespace mediapipe { - -GlTextureBufferPool::GlTextureBufferPool(int width, int height, - GpuBufferFormat format, int keep_count) - : width_(width), - height_(height), - format_(format), - keep_count_(keep_count) {} - -GlTextureBufferSharedPtr GlTextureBufferPool::GetBuffer() { - std::unique_ptr buffer; - bool reuse = false; - - { - absl::MutexLock lock(&mutex_); - if (available_.empty()) { - buffer = GlTextureBuffer::Create(width_, height_, format_); - if (!buffer) return nullptr; - } else { - buffer = std::move(available_.back()); - available_.pop_back(); - reuse = true; - } - - ++in_use_count_; - } - - // This needs to wait on consumer sync points, therefore it should not be - // done while holding the mutex. - if (reuse) { - buffer->Reuse(); - } - - // Return a shared_ptr with a custom deleter that adds the buffer back - // to our available list. - std::weak_ptr weak_pool(shared_from_this()); - return std::shared_ptr( - buffer.release(), [weak_pool](GlTextureBuffer* buf) { - auto pool = weak_pool.lock(); - if (pool) { - pool->Return(absl::WrapUnique(buf)); - } else { - delete buf; - } - }); -} - -std::pair GlTextureBufferPool::GetInUseAndAvailableCounts() { - absl::MutexLock lock(&mutex_); - return {in_use_count_, available_.size()}; -} - -void GlTextureBufferPool::Return(std::unique_ptr buf) { - std::vector> trimmed; - { - absl::MutexLock lock(&mutex_); - --in_use_count_; - available_.emplace_back(std::move(buf)); - TrimAvailable(&trimmed); - } - // The trimmed buffers will be released without holding the lock. 
-} - -void GlTextureBufferPool::TrimAvailable( - std::vector>* trimmed) { - int keep = std::max(keep_count_ - in_use_count_, 0); - if (available_.size() > keep) { - auto trim_it = std::next(available_.begin(), keep); - if (trimmed) { - std::move(trim_it, available_.end(), std::back_inserter(*trimmed)); - } - available_.erase(trim_it, available_.end()); - } -} - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index 4dcad305e..726d0528d 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -23,11 +23,12 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gl_texture_buffer.h" +#include "mediapipe/gpu/multi_pool.h" +#include "mediapipe/gpu/reusable_pool.h" namespace mediapipe { -class GlTextureBufferPool - : public std::enable_shared_from_this { +class GlTextureBufferPool : public ReusablePool { public: // Creates a pool. This pool will manage buffers of the specified dimensions, // and will keep keep_count buffers around for reuse. @@ -36,42 +37,32 @@ class GlTextureBufferPool static std::shared_ptr Create(int width, int height, GpuBufferFormat format, int keep_count) { - return std::shared_ptr( - new GlTextureBufferPool(width, height, format, keep_count)); + return Create({width, height, format}, {.keep_count = keep_count}); } - // Obtains a buffers. May either be reused or created anew. - // A GlContext must be current when this is called. - GlTextureBufferSharedPtr GetBuffer(); + static std::shared_ptr Create( + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) { + return std::shared_ptr( + new GlTextureBufferPool(spec, options)); + } - int width() const { return width_; } - int height() const { return height_; } - GpuBufferFormat format() const { return format_; } + int width() const { return spec_.width; } + int height() const { return spec_.height; } + GpuBufferFormat format() const { return spec_.format; } - // This method is meant for testing. - std::pair GetInUseAndAvailableCounts(); + static GlTextureBufferSharedPtr CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec) { + return GlTextureBuffer::Create(spec); + } - private: - GlTextureBufferPool(int width, int height, GpuBufferFormat format, - int keep_count); + protected: + GlTextureBufferPool(const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) + : ReusablePool( + [this] { return GlTextureBuffer::Create(spec_); }, options), + spec_(spec) {} - // Return a buffer to the pool. - void Return(std::unique_ptr buf); - - // If the total number of buffers is greater than keep_count, destroys any - // surplus buffers that are no longer in use. 
- void TrimAvailable(std::vector>* trimmed) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - const int width_; - const int height_; - const GpuBufferFormat format_; - const int keep_count_; - - absl::Mutex mutex_; - int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0; - std::vector> available_ - ABSL_GUARDED_BY(mutex_); + const internal::GpuBufferSpec spec_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.cc b/mediapipe/gpu/gl_texture_util.cc new file mode 100644 index 000000000..603e82a46 --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.cc @@ -0,0 +1,30 @@ +#include "mediapipe/gpu/gl_texture_util.h" + +namespace mediapipe { + +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { + glViewport(0, 0, src.width(), src.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), + src.name(), 0); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(dst.target(), dst.name()); + glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); + + glBindTexture(dst.target(), 0); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, + 0); +} + +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, + float a) { + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glClearColor(r, g, b, a); + glClear(GL_COLOR_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, + 0); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.h b/mediapipe/gpu/gl_texture_util.h new file mode 100644 index 000000000..73ac37ade --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ +#define MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ + +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_texture_view.h" + +namespace mediapipe { + +// Copies a texture to another. +// Assumes a framebuffer is already set up +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst); + +// Fills a texture with a color. +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, float a); + +// RAII class to set up a temporary framebuffer. Mainly for test use. 
+class TempGlFramebuffer { + public: + TempGlFramebuffer() { + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + } + ~TempGlFramebuffer() { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + glDeleteFramebuffers(1, &framebuffer_); + } + + private: + GLuint framebuffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ diff --git a/mediapipe/gpu/gl_texture_view.cc b/mediapipe/gpu/gl_texture_view.cc index 5d1862ddc..cae4039a4 100644 --- a/mediapipe/gpu/gl_texture_view.cc +++ b/mediapipe/gpu/gl_texture_view.cc @@ -7,7 +7,6 @@ void GlTextureView::Release() { if (detach_) detach_(*this); detach_ = nullptr; gl_context_ = nullptr; - gpu_buffer_ = nullptr; plane_ = 0; name_ = 0; width_ = 0; diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 8b47d620b..8a257cf53 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -25,8 +25,6 @@ namespace mediapipe { class GlContext; -class GlTextureViewManager; -class GpuBuffer; class GlTextureView { public: @@ -43,7 +41,6 @@ class GlTextureView { name_ = other.name_; width_ = other.width_; height_ = other.height_; - gpu_buffer_ = std::move(other.gpu_buffer_); plane_ = other.plane_; detach_ = std::exchange(other.detach_, nullptr); done_writing_ = std::exchange(other.done_writing_, nullptr); @@ -55,26 +52,23 @@ class GlTextureView { int height() const { return height_; } GLenum target() const { return target_; } GLuint name() const { return name_; } - const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; } int plane() const { return plane_; } using DetachFn = std::function; using DoneWritingFn = std::function; private: - friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, - int height, std::shared_ptr gpu_buffer, int plane, - DetachFn detach, DoneWritingFn done_writing) + int height, int plane, DetachFn detach, + DoneWritingFn done_writing) : gl_context_(context), target_(target), name_(name), width_(width), height_(height), - gpu_buffer_(std::move(gpu_buffer)), plane_(plane), detach_(std::move(detach)), done_writing_(std::move(done_writing)) {} @@ -93,7 +87,6 @@ class GlTextureView { // Note: when scale is not 1, we still give the nominal size of the image. int width_ = 0; int height_ = 0; - std::shared_ptr gpu_buffer_; // using shared_ptr temporarily int plane_ = 0; DetachFn detach_; mutable DoneWritingFn done_writing_; @@ -112,12 +105,8 @@ class ViewProvider { // the same view implement the same signature. // Note that we allow different views to have custom signatures, providing // additional view-specific arguments that may be needed. 
- virtual GlTextureView GetReadView(types, - std::shared_ptr gpu_buffer, - int plane) const = 0; - virtual GlTextureView GetWriteView(types, - std::shared_ptr gpu_buffer, - int plane) = 0; + virtual GlTextureView GetReadView(types, int plane) const = 0; + virtual GlTextureView GetWriteView(types, int plane) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index e570ce8ba..628e86099 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -1,7 +1,9 @@ #include "mediapipe/gpu/gpu_buffer.h" #include +#include +#include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -24,66 +26,122 @@ struct StorageTypeFormatter { } // namespace std::string GpuBuffer::DebugString() const { - return absl::StrCat("GpuBuffer[", - absl::StrJoin(storages_, ", ", StorageTypeFormatter()), - "]"); + return holder_ ? absl::StrCat("GpuBuffer[", width(), "x", height(), " ", + format(), " as ", holder_->DebugString(), "]") + : "GpuBuffer[invalid]"; } -internal::GpuBufferStorage& GpuBuffer::GetStorageForView( - TypeId view_provider_type, bool for_writing) const { - const std::shared_ptr* chosen_storage = nullptr; +std::string GpuBuffer::StorageHolder::DebugString() const { + absl::MutexLock lock(&mutex_); + return absl::StrJoin(storages_, ", ", StorageTypeFormatter()); +} - // First see if any current storage supports the view. - for (const auto& s : storages_) { - if (s->can_down_cast_to(view_provider_type)) { - chosen_storage = &s; - break; +internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView( + TypeId view_provider_type, bool for_writing) const { + std::shared_ptr chosen_storage; + std::function()> conversion; + + { + absl::MutexLock lock(&mutex_); + // First see if any current storage supports the view. + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + conversion = absl::bind_front(converter, s); + break; + } + } } } - // Then try to convert existing storages to one that does. - // TODO: choose best conversion. - if (!chosen_storage) { + // Avoid invoking a converter or factory while holding the mutex. + // Two reasons: + // 1. Readers that don't need a conversion will not be blocked. + // 2. We use mutexes to make sure GL contexts are not used simultaneously on + // different threads, and we also rely on Mutex's deadlock detection + // heuristic, which enforces a consistent mutex acquisition order. + // This function is likely to be called within a GL context, and the + // conversion function may in turn use a GL context, and this may cause a + // false positive in the deadlock detector. + // TODO: we could use Mutex::ForgetDeadlockInfo instead. + if (conversion) { + auto new_storage = conversion(); + absl::MutexLock lock(&mutex_); + // Another reader might have already completed and inserted the same + // conversion. TODO: prevent this? 
for (const auto& s : storages_) { - auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider(view_provider_type, - s->storage_type()); - if (converter) { - storages_.push_back(converter(s)); - chosen_storage = &storages_.back(); + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; } } + if (!chosen_storage) { + storages_.push_back(std::move(new_storage)); + chosen_storage = storages_.back(); + } } if (for_writing) { - if (!chosen_storage) { - // Allocate a new storage supporting the requested view. - auto factory = internal::GpuBufferStorageRegistry::Get() - .StorageFactoryForViewProvider(view_provider_type); - if (factory) { - storages_ = {factory(width(), height(), format())}; - chosen_storage = &storages_.back(); - } - } else { + // This will temporarily hold storages to be released, and do so while the + // lock is not held (see above). + decltype(storages_) old_storages; + using std::swap; + if (chosen_storage) { // Discard all other storages. - storages_ = {*chosen_storage}; - chosen_storage = &storages_.back(); + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {chosen_storage}; + } else { + // Allocate a new storage supporting the requested view. + if (auto factory = + internal::GpuBufferStorageRegistry::Get() + .StorageFactoryForViewProvider(view_provider_type)) { + if (auto new_storage = factory(width_, height_, format_)) { + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {std::move(new_storage)}; + chosen_storage = storages_.back(); + } + } } } + // It is ok to return a non-owning storage pointer here because this object + // ensures the storage's lifetime. Overwriting a GpuBuffer while readers are + // active would violate this, but it's not allowed in MediaPipe. + return chosen_storage ? chosen_storage.get() : nullptr; +} + +internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( + TypeId view_provider_type, bool for_writing) const { + auto* chosen_storage = + GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " - << absl::StrJoin(storages_, ", ", - StorageTypeFormatter()); - DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); - return **chosen_storage; + << (holder_ ? holder_->DebugString() : "invalid"); + DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); + return *chosen_storage; } #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer) { - auto p = buffer.internal_storage(); - if (p) return **p; + if (buffer.GetStorageForView( + kTypeId>, + /*for_writing=*/false) != nullptr) { + return *buffer.GetReadView(); + } return nullptr; } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 57e077151..b9a88aa53 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -15,9 +15,12 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_H_ +#include +#include #include #include +#include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -56,8 +59,7 @@ class GpuBuffer { // Creates an empty buffer of a given size and format. It will be allocated // when a view is requested. 
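With the storage refactor, GetReadView and GetWriteView no longer receive a shared_ptr to the GpuBuffer; callers pass only the view-specific arguments (a plane index for GlTextureView, nothing for ImageFrame). The sketch below shows the updated call sites and assumes it runs on a thread with a current GL context (for example inside a calculator's GL task); WriteThenRead and the logging are illustrative only, not part of the change.

#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/gpu/gl_texture_view.h"
#include "mediapipe/gpu/gpu_buffer.h"

namespace mediapipe {

// Sketch of the updated view accessors; not a complete rendering pipeline.
void WriteThenRead(GpuBuffer& buffer) {
  {
    // Only the plane index is passed now.
    GlTextureView write_view = buffer.GetWriteView<GlTextureView>(/*plane=*/0);
    // ... render into write_view.name() / write_view.target() here ...
  }
  // Reading as an ImageFrame converts between storages as needed.
  auto image_frame = buffer.GetReadView<ImageFrame>();
  LOG(INFO) << "read back " << image_frame->Width() << "x"
            << image_frame->Height();
}

}  // namespace mediapipe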
GpuBuffer(int width, int height, Format format) - : GpuBuffer(std::make_shared(width, height, - format)) {} + : holder_(std::make_shared(width, height, format)) {} // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; @@ -70,9 +72,8 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. - explicit GpuBuffer(std::shared_ptr storage) { - storages_.push_back(std::move(storage)); - } + explicit GpuBuffer(std::shared_ptr storage) + : holder_(std::make_shared(std::move(storage))) {} #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from @@ -84,9 +85,11 @@ class GpuBuffer { : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} #endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - int width() const { return current_storage().width(); } - int height() const { return current_storage().height(); } - GpuBufferFormat format() const { return current_storage().format(); } + int width() const { return holder_ ? holder_->width() : 0; } + int height() const { return holder_ ? holder_->height() : 0; } + GpuBufferFormat format() const { + return holder_ ? holder_->format() : GpuBufferFormat::kUnknown; + } // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } @@ -105,18 +108,16 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetReadView(Args... args) const { - return GetViewProvider(false)->GetReadView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + return GetViewProviderOrDie(false).GetReadView( + internal::types{}, std::forward(args)...); } // Gets a write view of the specified type. The arguments depend on the // specific view type; see the corresponding ViewProvider. template decltype(auto) GetWriteView(Args... args) { - return GetViewProvider(true)->GetWriteView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + return GetViewProviderOrDie(true).GetWriteView( + internal::types{}, std::forward(args)...); } // Attempts to access an underlying storage object of the specified type. @@ -124,69 +125,87 @@ class GpuBuffer { // using views. template std::shared_ptr internal_storage() const { - for (const auto& s : storages_) - if (s->down_cast()) return std::static_pointer_cast(s); - return nullptr; + return holder_ ? holder_->internal_storage() : nullptr; } std::string DebugString() const; private: - class PlaceholderGpuBufferStorage - : public internal::GpuBufferStorageImpl { + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const { + return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing) + : nullptr; + } + + internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, + bool for_writing) const; + + template + internal::ViewProvider& GetViewProviderOrDie(bool for_writing) const { + using VP = internal::ViewProvider; + return *GetStorageForViewOrDie(kTypeId, for_writing) + .template down_cast(); + } + + // This class manages a set of alternative storages for the contents of a + // GpuBuffer. 
GpuBuffer was originally designed as a reference-type object, + // where a copy represents another reference to the same contents, so multiple + // GpuBuffer instances can share the same StorageHolder. + class StorageHolder { public: - PlaceholderGpuBufferStorage(int width, int height, Format format) + explicit StorageHolder(std::shared_ptr storage) + : StorageHolder(storage->width(), storage->height(), + storage->format()) { + storages_.push_back(std::move(storage)); + } + explicit StorageHolder(int width, int height, Format format) : width_(width), height_(height), format_(format) {} - int width() const override { return width_; } - int height() const override { return height_; } - GpuBufferFormat format() const override { return format_; } + + int width() const { return width_; } + int height() const { return height_; } + GpuBufferFormat format() const { return format_; } + + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const; + + template + std::shared_ptr internal_storage() const { + absl::MutexLock lock(&mutex_); + for (const auto& s : storages_) + if (s->down_cast()) return std::static_pointer_cast(s); + return nullptr; + } + + std::string DebugString() const; private: int width_ = 0; int height_ = 0; GpuBufferFormat format_ = GpuBufferFormat::kUnknown; + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. + mutable absl::Mutex mutex_; + mutable std::vector> storages_ + ABSL_GUARDED_BY(mutex_); }; - internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type, - bool for_writing) const; + std::shared_ptr holder_; - template - internal::ViewProvider* GetViewProvider(bool for_writing) const { - using VP = internal::ViewProvider; - return GetStorageForView(kTypeId, for_writing).template down_cast(); - } - - std::shared_ptr& no_storage() const { - static auto placeholder = - std::static_pointer_cast( - std::make_shared( - 0, 0, GpuBufferFormat::kUnknown)); - return placeholder; - } - - const internal::GpuBufferStorage& current_storage() const { - return storages_.empty() ? *no_storage() : *storages_[0]; - } - - internal::GpuBufferStorage& current_storage() { - return storages_.empty() ? *no_storage() : *storages_[0]; - } - - // This is mutable because view methods that do not change the contents may - // still need to allocate new storages. 
- mutable std::vector> storages_; +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storages_.empty(); + return holder_ == other; } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storages_ == other.storages_; + return holder_ == other.holder_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storages_.clear(); + holder_ = other; return *this; } diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 1dcd58e63..8e2e3858e 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -212,6 +212,10 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kRGBAHalf64: case GpuBufferFormat::kRGBAFloat128: + case GpuBufferFormat::kNV12: + case GpuBufferFormat::kNV21: + case GpuBufferFormat::kI420: + case GpuBufferFormat::kYV12: case GpuBufferFormat::kUnknown: return ImageFormat::UNKNOWN; } diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 45f054d31..5d77afeb6 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -52,6 +52,14 @@ enum class GpuBufferFormat : uint32_t { kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible. kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'), kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'), + // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling. + kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'), + // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling. + kNV21 = MEDIAPIPE_FOURCC('N', 'V', '2', '1'), + // 8-bit Y plane + non-interleaved 8-bit U/V planes with 2x2 subsampling. + kI420 = MEDIAPIPE_FOURCC('I', '4', '2', '0'), + // 8-bit Y plane + non-interleaved 8-bit V/U planes with 2x2 subsampling. 
+ kYV12 = MEDIAPIPE_FOURCC('Y', 'V', '1', '2'), }; #if !MEDIAPIPE_DISABLE_GPU @@ -111,6 +119,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_64RGBAHalf; case GpuBufferFormat::kRGBAFloat128: return kCVPixelFormatType_128RGBAFloat; + case GpuBufferFormat::kNV12: + case GpuBufferFormat::kNV21: + case GpuBufferFormat::kI420: + case GpuBufferFormat::kYV12: case GpuBufferFormat::kUnknown: return -1; } @@ -153,6 +165,34 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { #endif // __APPLE__ +namespace internal { + +struct GpuBufferSpec { + GpuBufferSpec(int w, int h, GpuBufferFormat f) + : width(w), height(h), format(f) {} + + template + friend H AbslHashValue(H h, const GpuBufferSpec& spec) { + return H::combine(std::move(h), spec.width, spec.height, + static_cast(spec.format)); + } + + int width; + int height; + GpuBufferFormat format; +}; + +// BufferSpec equality operators +inline bool operator==(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return lhs.width == rhs.width && lhs.height == rhs.height && + lhs.format == rhs.format; +} +inline bool operator!=(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace internal + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_FORMAT_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 6e4fd38ea..e2ed523e4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -16,204 +16,7 @@ #include -#include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -#include "CoreFoundation/CFBase.h" -#include "mediapipe/objc/CFHolder.h" -#include "mediapipe/objc/util.h" -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -namespace mediapipe { - -// Keep this many buffers allocated for a given frame size. -static constexpr int kKeepCount = 2; -// The maximum size of the GpuBufferMultiPool. When the limit is reached, the -// oldest BufferSpec will be dropped. -static constexpr int kMaxPoolCount = 10; -// Time in seconds after which an inactive buffer can be dropped from the pool. -// Currently only used with CVPixelBufferPool. -static constexpr float kMaxInactiveBufferAge = 0.25; -// Skip allocating a buffer pool until at least this many requests have been -// made for a given BufferSpec. -static constexpr int kMinRequestsBeforePool = 2; -// Do a deeper flush every this many requests. 
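GpuBufferSpec above defines AbslHashValue and equality, so it can be used directly as a hash-map key; this is how the pooling code groups buffers by size and format. A small illustrative sketch follows; the request-counting map is hypothetical.

#include "absl/container/flat_hash_map.h"
#include "mediapipe/gpu/gpu_buffer_format.h"

namespace mediapipe {

// Hypothetical: count buffer requests per (width, height, format).
absl::flat_hash_map<internal::GpuBufferSpec, int> request_count;

void CountRequest(int width, int height, GpuBufferFormat format) {
  ++request_count[internal::GpuBufferSpec(width, height, format)];
}

}  // namespace mediapipe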
-static constexpr int kRequestCountScrubInterval = 50; - -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( - const GpuBufferMultiPool::BufferSpec& spec, CFTimeInterval maxAge) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - pool_ = MakeCFHolderAdopting( - /* keep count is 0 because the age param keeps buffers around anyway */ - CreateCVPixelBufferPool(spec.width, spec.height, cv_format, 0, maxAge)); -} - -GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { - CVPixelBufferRef buffer; - int threshold = 1; - NSMutableDictionary* auxAttributes = - [NSMutableDictionary dictionaryWithCapacity:1]; - CVReturn err; - bool tried_flushing = false; - while (1) { - auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); - err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( - kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, - &buffer); - if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush && !tried_flushing) { - // Call the flush function to potentially release old holds on buffers - // and try again to create a pixel buffer. - // This is used to flush CV texture caches, which may retain buffers until - // flushed. - flush(); - tried_flushing = true; - } else { - ++threshold; - } - } - CHECK(!err) << "Error creating pixel buffer: " << err; - count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -std::string CvPixelBufferPoolWrapper::GetDebugString() const { - auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); - return [(__bridge NSString*)*description UTF8String]; -} - -void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared(spec, - kMaxInactiveBufferAge); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - CVPixelBufferRef buffer; - CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, - cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -void GpuBufferMultiPool::FlushTextureCaches() { - absl::MutexLock lock(&mutex_); - for (const auto& cache : texture_caches_) { -#if TARGET_OS_OSX - CVOpenGLTextureCacheFlush(*cache, 0); -#else - CVOpenGLESTextureCacheFlush(*cache, 0); -#endif // TARGET_OS_OSX - } -} - -// Turning this on disables the pixel buffer pools when using the simulator. -// It is no longer necessary, since the helper code now supports non-contiguous -// buffers. We leave the code in for now for the sake of documentation. -#define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { -#if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR - // On the simulator, syncing the texture with the pixelbuffer does not work, - // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not - // available in OpenGL ES 2, we should create the buffer so the pixels are - // contiguous. - // - // TODO: verify if we can use kIOSurfaceBytesPerRow to force the - // pool to give us contiguous data. 
- return GetBufferWithoutPool(spec); -#else - return pool->GetBuffer([this]() { FlushTextureCaches(); }); -#endif // TARGET_IPHONE_SIMULATOR -} - -#else - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const BufferSpec& spec) { - return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - kKeepCount); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer( - GlTextureBuffer::Create(spec.width, spec.height, spec.format)); -} - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool->GetBuffer()); -} - -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( - const BufferSpec& spec) { - SimplePool pool; - std::vector evicted; - { - absl::MutexLock lock(&mutex_); - pool = - cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= kMinRequestsBeforePool) - ? MakeSimplePool(spec) - : nullptr; - }); - evicted = cache_.Evict(kMaxPoolCount, kRequestCountScrubInterval); - } - // Evicted pools, and their buffers, will be released without holding the - // lock. - return pool; -} - -GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, - GpuBufferFormat format) { - BufferSpec key(width, height, format); - SimplePool pool = RequestPool(key); - if (pool) { - // Note: we release our multipool lock before accessing the simple pool. - return GetBufferFromSimplePool(key, pool); - } else { - return GetBufferWithoutPool(key); - } -} - -GpuBufferMultiPool::~GpuBufferMultiPool() { -#ifdef __APPLE__ - CHECK_EQ(texture_caches_.size(), 0) - << "Failed to unregister texture caches before deleting pool"; -#endif // defined(__APPLE__) -} - -#ifdef __APPLE__ -void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) - << "Attempting to register a texture cache twice"; - texture_caches_.emplace_back(cache); -} - -void GpuBufferMultiPool::UnregisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) - << "Attempting to unregister an unknown texture cache"; - texture_caches_.erase(it); -} -#endif // defined(__APPLE__) - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 5ea6e314f..827cf514a 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -22,120 +22,35 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ -#include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" -#include "mediapipe/util/resource_cache.h" +#include "mediapipe/gpu/multi_pool.h" -#ifdef __APPLE__ -#include "mediapipe/gpu/pixel_buffer_pool_util.h" -#endif // __APPLE__ - -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" +#else #include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER namespace mediapipe { -struct GpuSharedData; class 
CvPixelBufferPoolWrapper; -class GpuBufferMultiPool { - public: - GpuBufferMultiPool() {} - explicit GpuBufferMultiPool(void* ignored) {} - ~GpuBufferMultiPool(); - - // Obtains a buffer. May either be reused or created anew. - GpuBuffer GetBuffer(int width, int height, - GpuBufferFormat format = GpuBufferFormat::kBGRA32); - -#ifdef __APPLE__ - // TODO: add tests for the texture cache registration. - - // Inform the pool of a cache that should be flushed when it is low on - // reusable buffers. - void RegisterTextureCache(CVTextureCacheType cache); - - // Remove a texture cache from the list of caches to be flushed. - void UnregisterTextureCache(CVTextureCacheType cache); - - void FlushTextureCaches(); -#endif // defined(__APPLE__) - - // This class is not intended as part of the public api of this class. It is - // public only because it is used as a map key type, and the map - // implementation needs access to, e.g., the equality operator. - struct BufferSpec { - BufferSpec(int w, int h, mediapipe::GpuBufferFormat f) - : width(w), height(h), format(f) {} - - template - friend H AbslHashValue(H h, const BufferSpec& spec) { - return H::combine(std::move(h), spec.width, spec.height, - static_cast(spec.format)); - } - - int width; - int height; - mediapipe::GpuBufferFormat format; - }; - - private: +class GpuBufferMultiPool : public MultiPool< #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = std::shared_ptr; + CvPixelBufferPoolWrapper, #else - using SimplePool = std::shared_ptr; + GlTextureBufferPool, #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - - SimplePool MakeSimplePool(const BufferSpec& spec); - // Requests a simple buffer pool for the given spec. This may return nullptr - // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke GetBufferWithoutPool instead - // of GetBufferFromSimplePool. - SimplePool RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, const SimplePool& pool); - GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); - - absl::Mutex mutex_; - mediapipe::ResourceCache> - cache_ ABSL_GUARDED_BY(mutex_); - -#ifdef __APPLE__ - // Texture caches used with this pool. 
- std::vector> texture_caches_ - ABSL_GUARDED_BY(mutex_); -#endif // defined(__APPLE__) -}; - -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -class CvPixelBufferPoolWrapper { + internal::GpuBufferSpec, GpuBuffer> { public: - CvPixelBufferPoolWrapper(const GpuBufferMultiPool::BufferSpec& spec, - CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); + using MultiPool::MultiPool; - int GetBufferCount() const { return count_; } - std::string GetDebugString() const; - - void Flush(); - - private: - CFHolder pool_; - int count_ = 0; + GpuBuffer GetBuffer(int width, int height, + GpuBufferFormat format = GpuBufferFormat::kBGRA32) { + return Get(internal::GpuBufferSpec(width, height, format)); + } }; -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -// BufferSpec equality operators -inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return lhs.width == rhs.width && lhs.height == rhs.height && - lhs.format == rhs.format; -} -inline bool operator!=(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return !operator==(lhs, rhs); -} } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 3d872eb66..19661d930 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -13,22 +13,57 @@ #include "mediapipe/gpu/gpu_buffer_format.h" namespace mediapipe { -class GpuBuffer; namespace internal { template struct types {}; +// This template must be specialized for each view type V. Each specialization +// should define a pair of virtual methods called GetReadView and GetWriteView, +// whose first argument is a types tag object. The result type and optional +// further arguments will depend on the view type. +// +// Example: +// template <> +// class ViewProvider { +// public: +// virtual ~ViewProvider() = default; +// virtual MyView GetReadView(types) const = 0; +// virtual MyView GetWriteView(types) = 0; +// }; +// +// The additional arguments and result type are reflected in GpuBuffer's +// GetReadView and GetWriteView methods. +// +// Using a type tag for the first argument allows the methods to be overloaded, +// so that a single storage can implement provider methods for multiple views. +// Since these methods are not template methods, they can (and should) be +// virtual, which allows storage classes to override them, enforcing that all +// storages providing a given view type implement the same interface. template class ViewProvider; -// Interface for a backing storage for GpuBuffer. +// Generic interface for a backing storage for GpuBuffer. +// +// GpuBuffer is an opaque handle to an image. Its contents are handled by +// Storage classes. Application code does not interact with the storages +// directly; to access the data, it asks the GpuBuffer for a View, and in turn +// GpuBuffer looks for a storage that can provide that view. +// This architecture decouples application code from the underlying storage, +// making it possible to use platform-specific optimized storage systems, e.g. +// for zero-copy data sharing between CPU and GPU. +// +// Storage implementations should inherit from GpuBufferStorageImpl. See that +// class for details. class GpuBufferStorage { public: virtual ~GpuBufferStorage() = default; + + // Concrete storage types should override the following three accessors. 
virtual int width() const = 0; virtual int height() const = 0; virtual GpuBufferFormat format() const = 0; + // We can't use dynamic_cast since we want to support building without RTTI. // The public methods delegate to the type-erased private virtual method. template @@ -72,19 +107,33 @@ class GpuBufferStorageRegistry { return *registry; } + // Registers a storage type by automatically creating a factory for it. + // This is normally called by GpuBufferImpl. template RegistryToken Register() { - return Register( + return RegisterFactory( [](int width, int height, GpuBufferFormat format) -> std::shared_ptr { return CreateStorage(overload_priority<10>{}, width, height, format); - }, - Storage::GetProviderTypes()); + }); } + // Registers a new factory for a storage type. + template + RegistryToken RegisterFactory(F&& factory) { + if constexpr (kDisableRegistration) { + return {}; + } + return Register(factory, Storage::GetProviderTypes()); + } + + // Registers a new converter from storage type StorageFrom to StorageTo. template RegistryToken RegisterConverter(F&& converter) { + if constexpr (kDisableRegistration) { + return {}; + } return Register( [converter](std::shared_ptr source) -> std::shared_ptr { @@ -115,6 +164,13 @@ class GpuBufferStorageRegistry { return std::make_shared(args...); } + // Temporary workaround: a Storage class can define a static constexpr + // kDisableGpuBufferRegistration member to true to prevent registering any + // factory of converter that would produce it. + // TODO: better solution for storage priorities. + template + static constexpr bool kDisableRegistration = false; + RegistryToken Register(StorageFactory factory, std::vector provider_hashes); RegistryToken Register(StorageConverter converter, @@ -126,6 +182,13 @@ class GpuBufferStorageRegistry { converter_for_view_provider_and_existing_storage_; }; +// Putting this outside the class body to work around a GCC bug. +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=71954 +template +constexpr bool GpuBufferStorageRegistry::kDisableRegistration< + Storage, std::void_t> = + Storage::kDisableGpuBufferRegistration; + // Defining a member of this type causes P to be ODR-used, which forces its // instantiation if it's a static member of a template. template @@ -138,21 +201,41 @@ struct ForceStaticInstantiation { #endif // _MSC_VER }; -// T: storage type -// U...: ViewProvider +// Inherit from this class to define a new storage type. The storage type itself +// should be passed as the first template argument (CRTP), followed by one or +// more specializations of ViewProvider. +// +// Concrete storage types should implement the basic accessors from +// GpuBufferStorage, plus the view read/write getters for each ViewProvider they +// implement. This class handles the rest. +// +// Arguments: +// T: storage type +// U...: ViewProvider +// Example: +// class MyStorage : public GpuBufferStorageImpl< +// MyStorage, ViewProvider> template class GpuBufferStorageImpl : public GpuBufferStorage, public U... { public: static const std::vector& GetProviderTypes() { - static std::vector kHashes{kTypeId...}; - return kHashes; + static std::vector kProviderIds{kTypeId...}; + return kProviderIds; + } + + // Exposing this as a function allows dependent initializers to call this to + // ensure proper ordering. 
+ static GpuBufferStorageRegistry::RegistryToken RegisterOnce() { + static auto registration = GpuBufferStorageRegistry::Get().Register(); + return registration; } private: - virtual const void* down_cast(TypeId to) const override { + // Allows a down_cast to any of the view provider types in U. + const void* down_cast(TypeId to) const final { return down_cast_impl(to, types{}); } - TypeId storage_type() const override { return kTypeId; } + TypeId storage_type() const final { return kTypeId; } const void* down_cast_impl(TypeId to, types<>) const { return nullptr; } template @@ -161,8 +244,7 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... { return down_cast_impl(to, types{}); } - inline static auto registration = - GpuBufferStorageRegistry::Get().Register(); + inline static auto registration = RegisterOnce(); using RequireStatics = ForceStaticInstantiation<®istration>; }; diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index d68ac0db0..7cac32b7f 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -26,8 +26,7 @@ GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( } GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( - std::shared_ptr gpu_buffer, int plane, - GlTextureView::DoneWritingFn done_writing) const { + int plane, GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); @@ -60,79 +59,92 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( cv_texture.adopt(cv_texture_temp); return GlTextureView( gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture), - CVOpenGLESTextureGetName(*cv_texture), width(), height(), - std::move(gpu_buffer), plane, + CVOpenGLESTextureGetName(*cv_texture), width(), height(), plane, [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, done_writing); #endif // TARGET_OS_OSX } GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { - return GetTexture(std::move(gpu_buffer), plane, nullptr); + internal::types, int plane) const { + return GetTexture(plane, nullptr); } +#if TARGET_IPHONE_SIMULATOR +static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, + const GlTextureView& view) { + CHECK(pixel_buffer); + auto ctx = GlContext::GetCurrent().get(); + if (!ctx) ctx = view.gl_context(); + ctx->Run([pixel_buffer, &view, ctx] { + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx)); + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), view.name(), 0); + + size_t contiguous_bytes_per_row = view.width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, pixel_ptr); + } else { + // TODO: use GL_PACK settings for row length. We can expect + // GLES 3.0 on iOS now. 
+ std::vector contiguous_buffer(contiguous_bytes_per_row * + view.height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, temp_ptr); + for (int i = 0; i < view.height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } + } + // TODO: restore previous framebuffer? + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), 0, 0); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; + } + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; + }); +} +#endif // TARGET_IPHONE_SIMULATOR + GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { - return GetTexture( - std::move(gpu_buffer), plane, - [this](const mediapipe::GlTextureView& view) { ViewDoneWriting(view); }); + internal::types, int plane) { + return GetTexture(plane, +#if TARGET_IPHONE_SIMULATOR + [pixel_buffer = CFHolder(*this)]( + const mediapipe::GlTextureView& view) { + ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view); + } +#else + nullptr +#endif // TARGET_IPHONE_SIMULATOR + ); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer) const { + internal::types) const { return CreateImageFrameForCVPixelBuffer(**this); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { + internal::types) { return CreateImageFrameForCVPixelBuffer(**this); } -void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { -#if TARGET_IPHONE_SIMULATOR - CVPixelBufferRef pixel_buffer = **this; - CHECK(pixel_buffer); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferLockBaseAddress failed: " << err; - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = - static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we - // can use BindFramebuffer? 
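To make the row-stride handling in the simulator workaround above concrete, a small worked example; the padding value is illustrative, and the actual stride is whatever CVPixelBufferGetBytesPerRow reports:

    // For a 300x200 kCVPixelFormatType_32BGRA pixel buffer:
    size_t contiguous_bytes_per_row = 300 * 4;  // 1200 bytes of pixel data per row
    size_t bytes_per_row = 1216;  // e.g. rows padded to a 64-byte boundary
    // When the two differ, glReadPixels fills a tightly packed temporary
    // buffer, and each 1200-byte row is then memcpy'd into the pixel buffer
    // at the padded 1216-byte stride, as in the loop above.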
- glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - - size_t contiguous_bytes_per_row = view.width() * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - pixel_ptr); - } else { - std::vector contiguous_buffer(contiguous_bytes_per_row * - view.height()); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - temp_ptr); - for (int i = 0; i < view.height(); ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; - } - } - } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferUnlockBaseAddress failed: " << err; -#endif -} - static std::shared_ptr ConvertFromImageFrame( std::shared_ptr frame) { auto status_or_buffer = diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index 017771dc7..8723a1087 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -12,10 +12,25 @@ namespace mediapipe { class GlContext; +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual CFHolder GetReadView( + internal::types) const = 0; + virtual CFHolder GetWriteView( + internal::types) = 0; +}; + +} // namespace internal + class GpuBufferStorageCvPixelBuffer : public internal::GpuBufferStorageImpl< GpuBufferStorageCvPixelBuffer, internal::ViewProvider, - internal::ViewProvider>, + internal::ViewProvider, + internal::ViewProvider>, public CFHolder { public: using CFHolder::CFHolder; @@ -33,24 +48,32 @@ class GpuBufferStorageCvPixelBuffer CVPixelBufferGetPixelFormatType(**this)); } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; + CFHolder GetReadView( + internal::types) const override; + CFHolder GetWriteView( + internal::types) override; private: - GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, + GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; - void ViewDoneWriting(const GlTextureView& view); }; +inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types) const { + return *this; +} +inline CFHolder GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types) { + return *this; +} + namespace internal { // These functions enable backward-compatible construction of a GpuBuffer from // CVPixelBufferRef without having to expose that type in the main GpuBuffer diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index 2cea3445e..ab547b9ea 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -29,13 +29,11 @@ class GpuBufferStorageImageFrame std::shared_ptr image_frame() const { return image_frame_; } 
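Side note (a sketch, not part of this patch): with the ViewProvider<CFHolder<CVPixelBufferRef>> specialization declared above, Apple-platform code can presumably retrieve the underlying pixel buffer through GpuBuffer's templated view accessors rather than by down-casting the storage. This assumes a buffer actually backed by GpuBufferStorageCvPixelBuffer (i.e. MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER builds):

    GpuBuffer buffer(640, 480, GpuBufferFormat::kBGRA32);
    CFHolder<CVPixelBufferRef> pixel_buffer =
        buffer.GetReadView<CFHolder<CVPixelBufferRef>>();
    // The CFHolder retains the CVPixelBufferRef; use *pixel_buffer to pass the
    // raw reference to CoreVideo APIs.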
std::shared_ptr image_frame() { return image_frame_; } std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override { + internal::types) const override { return image_frame_; } std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override { + internal::types) override { return image_frame_; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 3fd519b21..e4be617db 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,10 +14,14 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include + #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_buffer.h" +#include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" @@ -41,47 +45,6 @@ void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { } } -// Assumes a framebuffer is already set up -void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { - glViewport(0, 0, src.width(), src.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), - src.name(), 0); - - glActiveTexture(GL_TEXTURE0); - glBindTexture(dst.target(), dst.name()); - glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); - - glBindTexture(dst.target(), 0); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, - 0); -} - -void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, - float a) { - glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - glClearColor(r, g, b, a); - glClear(GL_COLOR_BUFFER_BIT); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, - 0); -} - -class TempGlFramebuffer { - public: - TempGlFramebuffer() { - glGenFramebuffers(1, &framebuffer_); - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - } - ~TempGlFramebuffer() { - glBindFramebuffer(GL_FRAMEBUFFER, 0); - glDeleteFramebuffers(1, &framebuffer_); - } - - private: - GLuint framebuffer_; -}; - class GpuBufferTest : public GpuTestBase {}; TEST_F(GpuBufferTest, BasicTest) { @@ -127,7 +90,7 @@ TEST_F(GpuBufferTest, GlTextureView) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view")); } @@ -162,7 +125,7 @@ TEST_F(GpuBufferTest, ImageFrame) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view")); } @@ -196,7 +159,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold")); 
MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view")); } @@ -230,7 +193,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame green(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(green, 0, 255, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, green, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view")); } @@ -240,11 +203,52 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame blue(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(blue, 0, 0, 255, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, blue, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view")); } } +TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + RunInGlContext([buffer = std::move(buffer)]() mutable { + // This is not a recommended pattern, but let's make sure that we don't + // crash if the buffer is released before the view. The view can hold + // callbacks into its underlying storage. + auto view = buffer.GetReadView(0); + buffer = nullptr; + }); + // We're really checking that we haven't crashed. + EXPECT_TRUE(true); +} + +TEST_F(GpuBufferTest, CopiesShareConversions) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer other_handle = buffer; + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetReadView(0); + }); + + // Check that other_handle also sees the same GlTextureBuffer as buffer. + // Note that this is deliberately written so that it still passes on platforms + // where we use another storage for GL textures (they will both be null). + // TODO: expose more accessors for testing? 
+ EXPECT_EQ(other_handle.internal_storage(), + buffer.internal_storage()); +} + } // anonymous namespace } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index a8bf0c3a3..203a8dfd1 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -21,7 +21,7 @@ #include "mediapipe/gpu/graph_support.h" #if __APPLE__ -#import "mediapipe/gpu/MPPGraphGPUData.h" +#include "mediapipe/gpu/metal_shared_resources.h" #endif // __APPLE__ namespace mediapipe { @@ -80,28 +80,40 @@ GpuResources::StatusOrGpuResources GpuResources::Create( return gpu_resources; } -GpuResources::GpuResources(std::shared_ptr gl_context) { +GpuResources::GpuResources(std::shared_ptr gl_context) +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + : texture_caches_(std::make_shared()), + gpu_buffer_pool_( + [tc = texture_caches_](const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec, options, tc.get()); + }) +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +{ gl_key_context_[SharedContextKey()] = gl_context; named_executors_[kGpuExecutorName] = std::make_shared(gl_context.get()); #if __APPLE__ - gpu_buffer_pool().RegisterTextureCache(gl_context->cv_texture_cache()); - ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() - multiPool:&gpu_buffer_pool_]; +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + metal_shared_ = std::make_unique(); #endif // __APPLE__ } GpuResources::~GpuResources() { #if __APPLE__ - // Note: on Apple platforms, this object contains Objective-C objects. The - // destructor will release them, but ARC must be on. + // Note: on Apple platforms, this object contains Objective-C objects. + // The destructor will release them, but ARC must be on. #if !__has_feature(objc_arc) #error This file must be built with ARC. 
#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER for (auto& kv : gl_key_context_) { - gpu_buffer_pool().UnregisterTextureCache(kv.second->cv_texture_cache()); + texture_caches_->UnregisterTextureCache(kv.second->cv_texture_cache()); } -#endif +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // __APPLE__ } absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { @@ -174,17 +186,43 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GlContext::Create(*gl_key_context_[SharedContextKey()], kGlContextUseDedicatedThread)); it = gl_key_context_.emplace(key, new_context).first; -#if __APPLE__ - gpu_buffer_pool_.RegisterTextureCache(it->second->cv_texture_cache()); -#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(it->second->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } return it->second; } GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {} -#if __APPLE__ -MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } -#endif // __APPLE__ +extern const GraphService kGpuService; + +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +static std::shared_ptr GetGlTextureBufferFromPool( + int width, int height, GpuBufferFormat format) { + std::shared_ptr texture_buffer; + const auto cc = LegacyCalculatorSupport::Scoped::current(); + + if (cc && cc->Service(kGpuService).IsAvailable()) { + GpuBufferMultiPool* pool = + &cc->Service(kGpuService).GetObject().gpu_buffer_pool(); + // Note that the "gpu_buffer_pool" serves GlTextureBuffers on non-Apple + // platforms. TODO: refactor into storage pools. + texture_buffer = pool->GetBuffer(width, height, format) + .internal_storage(); + } else { + texture_buffer = GlTextureBuffer::Create(width, height, format); + } + return texture_buffer; +} + +static auto kGlTextureBufferPoolRegistration = [] { + // Ensure that the GlTextureBuffer's own factory is already registered, so we + // can override it. + GlTextureBuffer::RegisterOnce(); + return internal::GpuBufferStorageRegistry::Get() + .RegisterFactory(GetGlTextureBufferFromPool); +}(); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 62d6bb27e..3f7c67e2e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -30,15 +30,15 @@ #include "mediapipe/gpu/gpu_buffer_multi_pool.h" #ifdef __APPLE__ -#ifdef __OBJC__ -@class MPPGraphGPUData; -#else -struct MPPGraphGPUData; -#endif // __OBJC__ +#include "mediapipe/gpu/cv_texture_cache_manager.h" #endif // defined(__APPLE__) namespace mediapipe { +#ifdef __APPLE__ +class MetalSharedResources; +#endif // defined(__APPLE__) + // TODO: rename to GpuService or GpuManager or something. class GpuResources { public: @@ -55,9 +55,7 @@ class GpuResources { // Shared GL context for calculators. // TODO: require passing a context or node identifier. 
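An illustrative note (not part of the patch): after the factory registration above, a size-based GpuBuffer allocation on platforms that use GlTextureBuffer is expected to be served from the graph's pool whenever kGpuService is available, falling back to GlTextureBuffer::Create() otherwise.

    // Inside a calculator with access to kGpuService, the backing storage is
    // expected to come from GpuResources::gpu_buffer_pool(); outside a graph
    // it is created directly.
    GpuBuffer buffer(640, 480, GpuBufferFormat::kBGRA32);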
- const std::shared_ptr& gl_context() { - return gl_context(nullptr); - }; + const std::shared_ptr& gl_context() { return gl_context(nullptr); } const std::shared_ptr& gl_context(CalculatorContext* cc); @@ -65,7 +63,7 @@ class GpuResources { GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; } #ifdef __APPLE__ - MPPGraphGPUData* ios_gpu_data(); + MetalSharedResources& metal_shared() { return *metal_shared_; } #endif // defined(__APPLE__) absl::Status PrepareGpuNode(CalculatorNode* node); @@ -86,13 +84,16 @@ class GpuResources { std::map node_key_; std::map> gl_key_context_; +#ifdef MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + std::shared_ptr texture_caches_; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // The pool must be destructed before the gl_context, but after the // ios_gpu_data, so the declaration order is important. GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - // Note that this is an Objective-C object. - MPPGraphGPUData* ios_gpu_data_; + std::unique_ptr metal_shared_; #endif // defined(__APPLE__) std::map> named_executors_; diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h index e9fd64725..6ec53603b 100644 --- a/mediapipe/gpu/gpu_test_base.h +++ b/mediapipe/gpu/gpu_test_base.h @@ -24,13 +24,14 @@ namespace mediapipe { class GpuTestBase : public ::testing::Test { protected: - GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); } + GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); } void RunInGlContext(std::function gl_func) { helper_.RunInGlContext(std::move(gl_func)); } GpuSharedData gpu_shared_; + std::shared_ptr gpu_resources_ = gpu_shared_.gpu_resources; GlCalculatorHelper helper_; }; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..c67fb0c62 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -12,73 +12,63 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif - namespace mediapipe { +namespace api2 { -// Convert ImageFrame to GpuBuffer.
-class ImageFrameToGpuBufferCalculator : public CalculatorBase { +class ImageFrameToGpuBufferCalculator + : public RegisteredNode { public: - ImageFrameToGpuBufferCalculator() {} + static constexpr Input kIn{""}; + static constexpr Output kOut{""}; - static absl::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; private: -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; -REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // static -absl::Status ImageFrameToGpuBufferCalculator::GetContract( +absl::Status ImageFrameToGpuBufferCalculator::UpdateContract( CalculatorContract* cc) { - cc->Inputs().Index(0).Set(); - cc->Outputs().Index(0).Set(); // Note: we call this method even on platforms where we don't use the helper, // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. - MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); - return absl::OkStatus(); + return GlCalculatorHelper::UpdateContract(cc); } absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { - // Inform the framework that we always output at the same timestamp - // as we receive a packet at. - cc->SetOffset(TimestampDiff(0)); -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - CFHolder buffer; - MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( - cc->Inputs().Index(0).Value(), &buffer)); - cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); -#else - const auto& input = cc->Inputs().Index(0).Get(); - helper_.RunInGlContext([this, &input, &cc]() { - auto src = helper_.CreateSourceTexture(input); - auto output = src.GetFrame(); - glFlush(); - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - src.Release(); - }); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto image_frame = std::const_pointer_cast( + mediapipe::SharedPtrWithPacket(kIn(cc).packet())); + auto gpu_buffer = api2::MakePacket( + std::make_shared( + std::move(image_frame))) + .At(cc->InputTimestamp()); + // This calculator's behavior has been to do the texture upload eagerly, and + // some graphs may rely on running this on a separate GL context to avoid + // blocking another context with the read operation. So let's request GPU + // access here to ensure that the behavior stays the same. + // TODO: have a better way to do this, or defer until later. 
+ helper_.RunInGlContext( + [&gpu_buffer] { auto view = gpu_buffer->GetReadView(0); }); + kOut(cc).Send(std::move(gpu_buffer)); return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/gpu/image_frame_view.h b/mediapipe/gpu/image_frame_view.h index 2fc6f2495..b7e58a824 100644 --- a/mediapipe/gpu/image_frame_view.h +++ b/mediapipe/gpu/image_frame_view.h @@ -12,9 +12,8 @@ class ViewProvider { public: virtual ~ViewProvider() = default; virtual std::shared_ptr GetReadView( - types, std::shared_ptr gpu_buffer) const = 0; - virtual std::shared_ptr GetWriteView( - types, std::shared_ptr gpu_buffer) = 0; + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/metal_shared_resources.h b/mediapipe/gpu/metal_shared_resources.h new file mode 100644 index 000000000..341860a2d --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.h @@ -0,0 +1,40 @@ +#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ +#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ + +#import +#import +#import +#import + +#ifndef __OBJC__ +#error This class must be built as Objective-C++. +#endif // !__OBJC__ + +@interface MPPMetalSharedResources : NSObject { +} + +- (instancetype)init NS_DESIGNATED_INITIALIZER; + +@property(readonly) id mtlDevice; +@property(readonly) id mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@property(readonly) CVMetalTextureCacheRef mtlTextureCache; +#endif + +@end + +namespace mediapipe { + +class MetalSharedResources { + public: + MetalSharedResources(); + ~MetalSharedResources(); + MPPMetalSharedResources* resources() { return resources_; } + + private: + MPPMetalSharedResources* resources_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.mm new file mode 100644 index 000000000..80d755a01 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.mm @@ -0,0 +1,73 @@ +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResources () +@end + +@implementation MPPMetalSharedResources { +} + +@synthesize mtlDevice = _mtlDevice; +@synthesize mtlCommandQueue = _mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@synthesize mtlTextureCache = _mtlTextureCache; +#endif + +- (instancetype)init { + self = [super init]; + if (self) { + } + return self; +} + +- (void)dealloc { +#if COREVIDEO_SUPPORTS_METAL + if (_mtlTextureCache) { + CFRelease(_mtlTextureCache); + _mtlTextureCache = NULL; + } +#endif +} + +- (id)mtlDevice { + @synchronized(self) { + if (!_mtlDevice) { + _mtlDevice = MTLCreateSystemDefaultDevice(); + } + } + return _mtlDevice; +} + +- (id)mtlCommandQueue { + @synchronized(self) { + if (!_mtlCommandQueue) { + _mtlCommandQueue = [self.mtlDevice newCommandQueue]; + } + } + return _mtlCommandQueue; +} + +#if COREVIDEO_SUPPORTS_METAL +- (CVMetalTextureCacheRef)mtlTextureCache { + @synchronized(self) { + if (!_mtlTextureCache) { + CVReturn __unused err = + CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err, + self.mtlDevice); + // TODO: register and flush metal caches too. 
+ } + } + return _mtlTextureCache; +} +#endif + +@end + +namespace mediapipe { + +MetalSharedResources::MetalSharedResources() { + resources_ = [[MPPMetalSharedResources alloc] init]; +} +MetalSharedResources::~MetalSharedResources() {} + +} // namespace mediapipe diff --git a/mediapipe/gpu/metal_shared_resources_test.mm b/mediapipe/gpu/metal_shared_resources_test.mm new file mode 100644 index 000000000..9eb53a9b7 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources_test.mm @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/threadpool.h" + +#import "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResourcesTests : XCTestCase { +} +@end + +@implementation MPPMetalSharedResourcesTests + +// This test verifies that the internal Objective-C object is correctly +// released when the C++ wrapper is released. +- (void)testCorrectlyReleased { + __weak id metalRes = nil; + std::weak_ptr weakGpuRes; + @autoreleasepool { + auto maybeGpuRes = mediapipe::GpuResources::Create(); + XCTAssertTrue(maybeGpuRes.ok()); + weakGpuRes = *maybeGpuRes; + metalRes = (**maybeGpuRes).metal_shared().resources(); + XCTAssertNotEqual(weakGpuRes.lock(), nullptr); + XCTAssertNotNil(metalRes); + } + XCTAssertEqual(weakGpuRes.lock(), nullptr); + XCTAssertNil(metalRes); +} + +@end diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h new file mode 100644 index 000000000..e677c3bbf --- /dev/null +++ b/mediapipe/gpu/multi_pool.h @@ -0,0 +1,119 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_MULTI_POOL_H_ +#define MEDIAPIPE_GPU_MULTI_POOL_H_ + +#include "mediapipe/util/resource_cache.h" + +namespace mediapipe { + +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. 
+ int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + +// MultiPool is a generic class for vending reusable resources of type Item, +// which are assumed to be relatively expensive to create, so that reusing them +// is beneficial. +// Items are classified by Spec; when an item with a given Spec is requested, +// an old Item with the same Spec can be reused, if available; otherwise a new +// Item will be created. When user code is done with an Item, it is returned +// to the pool for reuse. +// In order to manage this, a MultiPool contains a map of Specs to SimplePool; +// each SimplePool manages Items with the same Spec, which are thus considered +// interchangeable. +// Item retention and eviction policies are controlled by options. +// A concrete example would be a pool of GlTextureBuffer, grouped by dimensions +// and format. +template +class MultiPool { + public: + using SimplePoolFactory = std::function( + const Spec& spec, const MultiPoolOptions& options)>; + + MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool, + MultiPoolOptions options = kDefaultMultiPoolOptions) + : create_simple_pool_(factory), options_(options) {} + explicit MultiPool(MultiPoolOptions options) + : MultiPool(DefaultMakeSimplePool, options) {} + + // Obtains an item. May either be reused or created anew. + Item Get(const Spec& spec); + + private: + static std::shared_ptr DefaultMakeSimplePool( + const Spec& spec, const MultiPoolOptions& options) { + return SimplePool::Create(spec, options); + } + + // Requests a simple buffer pool for the given spec. This may return nullptr + // if we have not yet reached a sufficient number of requests to allocate a + // pool, in which case the caller should invoke CreateBufferWithoutPool. + std::shared_ptr RequestPool(const Spec& spec); + + absl::Mutex mutex_; + mediapipe::ResourceCache> cache_ + ABSL_GUARDED_BY(mutex_); + SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; + MultiPoolOptions options_; +}; + +template +std::shared_ptr MultiPool::RequestPool( + const Spec& spec) { + std::shared_ptr pool; + std::vector> evicted; + { + absl::MutexLock lock(&mutex_); + pool = cache_.Lookup(spec, [this](const Spec& spec, int request_count) { + return (request_count >= options_.min_requests_before_pool) + ? create_simple_pool_(spec, options_) + : nullptr; + }); + evicted = cache_.Evict(options_.max_pool_count, + options_.request_count_scrub_interval); + } + // Evicted pools, and their buffers, will be released without holding the + // lock. + return pool; +} + +template +Item MultiPool::Get(const Spec& spec) { + std::shared_ptr pool = RequestPool(spec); + if (pool) { + // Note: we release our multipool lock before accessing the simple pool. + return Item(pool->GetBuffer()); + } else { + return Item(SimplePool::CreateBufferWithoutPool(spec)); + } +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_MULTI_POOL_H_ diff --git a/mediapipe/gpu/reusable_pool.h b/mediapipe/gpu/reusable_pool.h new file mode 100644 index 000000000..ddeaa5ba7 --- /dev/null +++ b/mediapipe/gpu/reusable_pool.h @@ -0,0 +1,145 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Consider this file an implementation detail. None of this is part of the +// public API. + +#ifndef MEDIAPIPE_GPU_REUSABLE_POOL_H_ +#define MEDIAPIPE_GPU_REUSABLE_POOL_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/multi_pool.h" + +namespace mediapipe { + +template +class ReusablePool : public std::enable_shared_from_this> { + public: + using ItemFactory = absl::AnyInvocable() const>; + + // Creates a pool. This pool will manage buffers of the specified dimensions, + // and will keep keep_count buffers around for reuse. + // We enforce creation as a shared_ptr so that we can use a weak reference in + // the buffers' deleters. + static std::shared_ptr> Create( + ItemFactory item_factory, const MultiPoolOptions& options) { + return std::shared_ptr>( + new ReusablePool(std::move(item_factory), options)); + } + + // Obtains a buffer. May either be reused or created anew. + // A GlContext must be current when this is called. + std::shared_ptr GetBuffer(); + + // This method is meant for testing. + std::pair GetInUseAndAvailableCounts(); + + protected: + ReusablePool(ItemFactory item_factory, const MultiPoolOptions& options) + : item_factory_(std::move(item_factory)), + keep_count_(options.keep_count) {} + + private: + // Return a buffer to the pool. + void Return(std::unique_ptr buf); + + // If the total number of buffers is greater than keep_count, destroys any + // surplus buffers that are no longer in use. + void TrimAvailable(std::vector>* trimmed) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + const ItemFactory item_factory_; + const int keep_count_; + + absl::Mutex mutex_; + int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0; + std::vector> available_ ABSL_GUARDED_BY(mutex_); +}; + +template +inline std::shared_ptr ReusablePool::GetBuffer() { + std::unique_ptr buffer; + bool reuse = false; + + { + absl::MutexLock lock(&mutex_); + if (available_.empty()) { + buffer = item_factory_(); + if (!buffer) return nullptr; + } else { + buffer = std::move(available_.back()); + available_.pop_back(); + reuse = true; + } + + ++in_use_count_; + } + + // This needs to wait on consumer sync points, therefore it should not be + // done while holding the mutex. + if (reuse) { + buffer->Reuse(); + } + + // Return a shared_ptr with a custom deleter that adds the buffer back + // to our available list. + std::weak_ptr> weak_pool(this->shared_from_this()); + return std::shared_ptr(buffer.release(), [weak_pool](Item* buf) { + auto pool = weak_pool.lock(); + if (pool) { + pool->Return(absl::WrapUnique(buf)); + } else { + delete buf; + } + }); +} + +template +inline std::pair ReusablePool::GetInUseAndAvailableCounts() { + absl::MutexLock lock(&mutex_); + return {in_use_count_, available_.size()}; +} + +template +void ReusablePool::Return(std::unique_ptr buf) { + std::vector> trimmed; + { + absl::MutexLock lock(&mutex_); + --in_use_count_; + available_.emplace_back(std::move(buf)); + TrimAvailable(&trimmed); + } + // The trimmed buffers will be released without holding the lock. 
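To make the pool machinery above concrete, a hedged usage sketch (inside namespace mediapipe). MyItem and its factory are invented; the only requirement exercised here is the Reuse() hook that ReusablePool::GetBuffer() invokes when it hands out a previously returned item.

    struct MyItem {
      void Reuse() {}  // called by the pool before a reused item is handed out
    };

    MultiPoolOptions options;
    options.keep_count = 2;  // keep at most two idle items cached

    std::shared_ptr<ReusablePool<MyItem>> pool = ReusablePool<MyItem>::Create(
        []() { return std::make_unique<MyItem>(); }, options);

    std::shared_ptr<MyItem> item = pool->GetBuffer();
    // Dropping the last reference returns the item to the pool; if the pool is
    // gone by then, the item is simply deleted (see the weak_ptr in GetBuffer).
    item.reset();

MultiPool, described in multi_pool.h above, then layers a Spec-keyed ResourceCache of such simple pools on top of this.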
+} + +template +void ReusablePool::TrimAvailable( + std::vector>* trimmed) { + int keep = std::max(keep_count_ - in_use_count_, 0); + if (available_.size() > keep) { + auto trim_it = std::next(available_.begin(), keep); + if (trimmed) { + std::move(trim_it, available_.end(), std::back_inserter(*trimmed)); + } + available_.erase(trim_it, available_.end()); + } +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_REUSABLE_POOL_H_ diff --git a/mediapipe/gpu/shader_util.cc b/mediapipe/gpu/shader_util.cc index 2132cbda9..5de7e24f5 100644 --- a/mediapipe/gpu/shader_util.cc +++ b/mediapipe/gpu/shader_util.cc @@ -140,7 +140,7 @@ GLint GlhCreateProgram(const GLchar* vert_src, const GLchar* frag_src, glBindAttribLocation(*program, attr_locations[i], attr_names[i]); } - ok = GlhLinkProgram(*program); + ok = GlhLinkProgram(*program, force_log_errors); } if (vert_shader) glDeleteShader(vert_shader); diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index b177726bf..945f02c62 100644 --- a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:color_convert_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt index 36c6970e1..f48b26be0 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt @@ -60,7 +60,14 @@ node { tag_index: "LOOP" back_edge: true } - output_stream: "PREV_LOOP:previous_hair_mask" + output_stream: "PREV_LOOP:previous_hair_mask_rgb" +} + +# Converts the 4 channel hair mask to a single channel mask +node { + calculator: "ColorConvertCalculator" + input_stream: "RGB_IN:previous_hair_mask_rgb" + output_stream: "GRAY_OUT:previous_hair_mask" } # Embeds the hair mask generated from the previous round of hair segmentation diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index 3a3d57a0f..f5124b464 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -97,7 +97,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:image_file_properties_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index 783fff187..d4c5c496b 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "gl_animation_overlay_calculator_proto", srcs = ["gl_animation_overlay_calculator.proto"], + def_options_lib = False, visibility = ["//visibility:public"], exports = [ "//mediapipe/gpu:gl_animation_overlay_calculator_proto", diff --git 
a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 05700ba17..fc1e5484e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -15,10 +15,13 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; +import android.graphics.PixelFormat; +import android.media.Image; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImageProperties; +import com.google.mediapipe.framework.image.MediaImageExtractor; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator { } return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); } - + if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) { + Image mediaImage = MediaImageExtractor.extract(image); + if (mediaImage.getFormat() != PixelFormat.RGBA_8888) { + throw new UnsupportedOperationException("Android media image must use RGBA_8888 config."); + } + return createImage( + mediaImage.getPlanes()[0].getBuffer(), + mediaImage.getWidth(), + mediaImage.getHeight(), + /* numChannels= */ 4); + } // Unsupported type. throw new UnsupportedOperationException( "Unsupported Image container type: " + properties.getStorageType()); diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index efaec34a7..6a2c97b94 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -14,6 +14,10 @@ package com.google.mediapipe.framework; +import com.google.common.flogger.FluentLogger; +import java.util.HashSet; +import java.util.Set; + /** * A {@link TextureFrame} that represents a texture produced by MediaPipe. * @@ -21,6 +25,7 @@ package com.google.mediapipe.framework; * method. */ public class GraphTextureFrame implements TextureFrame { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private long nativeBufferHandle; // We cache these to be able to get them without a JNI call. private int textureName; @@ -30,6 +35,8 @@ public class GraphTextureFrame implements TextureFrame { // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait // when calling getTextureName(). private final boolean deferredSync; + private final Set activeConsumerContextHandleSet = new HashSet<>(); + private int refCount = 1; GraphTextureFrame(long nativeHandle, long timestamp) { this(nativeHandle, timestamp, false); @@ -54,17 +61,19 @@ public class GraphTextureFrame implements TextureFrame { * condition if release() is called after the if-check for nativeBufferHandle is already passed. */ @Override - public int getTextureName() { + public synchronized int getTextureName() { // Return special texture id 0 if handle is 0 i.e. frame is already released. if (nativeBufferHandle == 0) { return 0; } - // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using - // PacketGetter.getTextureFrameDeferredSync(). 
- if (deferredSync) { - // Note that, if a CPU wait has already been done, the sync point will have been - // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. - nativeGpuWait(nativeBufferHandle); + if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { + // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using + // PacketGetter.getTextureFrameDeferredSync(). + if (deferredSync) { + // Note that, if a CPU wait has already been done, the sync point will have been + // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. + nativeGpuWait(nativeBufferHandle); + } } return textureName; } @@ -86,15 +95,31 @@ public class GraphTextureFrame implements TextureFrame { return timestamp; } + @Override + public boolean supportsRetain() { + return true; + } + + @Override + public synchronized void retain() { + // TODO: check that refCount is > 0 and handle is not 0. + refCount++; + } + /** * Releases a reference to the underlying buffer. * *

The consumer calls this when it is done using the texture. */ @Override - public void release() { - GlSyncToken consumerToken = - new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + public synchronized void release() { + GlSyncToken consumerToken = null; + // Note that this remove should be moved to the other overload of release when b/68808951 is + // addressed. + if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { + consumerToken = + new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + } release(consumerToken); } @@ -108,18 +133,40 @@ public class GraphTextureFrame implements TextureFrame { * currently cannot create a GlSyncToken, so they cannot call this method. */ @Override - public void release(GlSyncToken consumerSyncToken) { - if (nativeBufferHandle != 0) { - long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); - nativeReleaseBuffer(nativeBufferHandle, token); - nativeBufferHandle = 0; + public synchronized void release(GlSyncToken consumerSyncToken) { + if (nativeBufferHandle == 0) { + if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); + } + return; } + if (consumerSyncToken != null) { + long token = consumerSyncToken.nativeToken(); + nativeDidRead(nativeBufferHandle, token); + // We should remove the token's context from activeConsumerContextHandleSet here, but for now + // we do it in the release(void) overload. consumerSyncToken.release(); } + + refCount--; + if (refCount <= 0) { + nativeReleaseBuffer(nativeBufferHandle); + nativeBufferHandle = 0; + } } - private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); + @Override + protected void finalize() throws Throwable { + if (refCount > 0 || nativeBufferHandle != 0) { + logger.atWarning().log("release was not called before finalize"); + } + if (!activeConsumerContextHandleSet.isEmpty()) { + logger.atWarning().log("active consumers did not release with sync before finalize"); + } + } + + private native void nativeReleaseBuffer(long nativeHandle); private native int nativeGetTextureName(long nativeHandle); private native int nativeGetWidth(long nativeHandle); @@ -128,4 +175,8 @@ public class GraphTextureFrame implements TextureFrame { private native void nativeGpuWait(long nativeHandle); private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); + + private native long nativeGetCurrentExternalContextHandle(); + + private native void nativeDidRead(long nativeHandle, long consumerSyncToken); } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d93eea7b5..04265cab5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -55,7 +55,11 @@ public class PacketCreator { public Packet createRgbImage(ByteBuffer buffer, int width, int height) { int widthStep = (((width * 3) + 3) / 4) * 4; if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + widthStep * height + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -123,7 +127,11 @@ public class PacketCreator { 
*/ public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { if (width * height * 4 != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + width * height * 4); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -136,7 +144,7 @@ public class PacketCreator { */ public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { if (width * height != buffer.capacity()) { - throw new RuntimeException( + throw new IllegalArgumentException( "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); } return Packet.create( @@ -150,7 +158,11 @@ public class PacketCreator { */ public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -163,7 +175,11 @@ public class PacketCreator { */ public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -354,25 +370,24 @@ public class PacketCreator { *

For 3 and 4 channel images, the pixel rows should have 4-byte alignment. */ public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { + int widthStep; if (numChannels == 4) { - if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); - } + widthStep = width * 4; } else if (numChannels == 3) { - int widthStep = (((width * 3) + 3) / 4) * 4; - if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); - } + widthStep = (((width * 3) + 3) / 4) * 4; } else if (numChannels == 1) { - if (width * height != buffer.capacity()) { - throw new RuntimeException( - "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); - } + widthStep = width; } else { - throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); + throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels); + } + int expectedSize = widthStep * height; + if (buffer.capacity() != expectedSize) { + throw new IllegalArgumentException( + "The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity()); } return Packet.create( - nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); + nativeCreateCpuImage( + mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels)); } /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ @@ -430,7 +445,7 @@ public class PacketCreator { long context, int name, int width, int height, TextureReleaseCallback releaseCallback); private native long nativeCreateCpuImage( - long context, ByteBuffer buffer, int width, int height, int numChannels); + long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels); private native long nativeCreateInt32Array(long context, int[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 7e66e0b75..92cf723e6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -199,6 +199,28 @@ public final class PacketGetter { return nativeGetImageData(packet.getNativeHandle(), buffer); } + /** Returns the size of the Image list. This helps to determine the size of the allocated ByteBuffer array. */ + public static int getImageListSize(final Packet packet) { + return nativeGetImageListSize(packet.getNativeHandle()); + } + + /** + * Assigns the native image buffers to the given ByteBuffer array. It assumes the given ByteBuffer + * array has the same size as the image list packet, and assumes the output buffer stores pixels + * contiguously. It returns false if this assumption does not hold. + * + *

If deepCopy is true, it assumes the given buffersArray has allocated ByteBuffers of the + * required size to copy the image data into. If false, each ByteBuffer will wrap the memory address of + * the corresponding MediaPipe ImageFrame in the graph output, and the ByteBuffer data is valid only while the MediaPipe + * graph is alive. + * + *

Note: this function does not assume the pixel format. + */ + public static boolean getImageList( + final Packet packet, ByteBuffer[] buffersArray, boolean deepCopy) { + return nativeGetImageList(packet.getNativeHandle(), buffersArray, deepCopy); + } + /** * Converts an RGB mediapipe image frame packet to an RGBA Byte buffer. * @@ -316,7 +338,8 @@ public final class PacketGetter { public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) { return new GraphTextureFrame( nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false), - packet.getTimestamp(), /* deferredSync= */true); + packet.getTimestamp(), + /* deferredSync= */ true); } private static native long nativeGetPacketFromReference(long nativePacketHandle); @@ -363,6 +386,11 @@ public final class PacketGetter { private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer); + private static native int nativeGetImageListSize(long nativePacketHandle); + + private static native boolean nativeGetImageList( + long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy); + private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer); // Retrieves the values that are in the VideoHeader. private static native int nativeGetVideoHeaderWidth(long nativepackethandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java index babfd2958..76eaf39df 100644 --- a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java @@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback { */ @Override void release(GlSyncToken syncToken); + + /** + * If this method returns true, this object supports the retain method, and can be used with + * multiple consumers. Call retain for each additional consumer beyond the first; each consumer + * should call release. + */ + default boolean supportsRetain() { + return false; + } + + /** Increments the reference count. Only available with some implementations of TextureFrame. 
*/ + default void retain() { + throw new UnsupportedOperationException(); + } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index bb3be318d..d9508c1f7 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -20,9 +20,7 @@ android_library( name = "image", srcs = glob(["*.java"]), manifest = "AndroidManifest.xml", - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//third_party:androidx_legacy_support_v4", "//third_party:autovalue", diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index 748a10667..68c53b0c4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -50,7 +50,10 @@ public class ByteBufferExtractor { switch (container.getImageProperties().getStorageType()) { case MPImage.STORAGE_TYPE_BYTEBUFFER: ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + return byteBufferImageContainer + .getByteBuffer() + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); default: throw new IllegalArgumentException( "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not" @@ -74,7 +77,7 @@ public class ByteBufferExtractor { * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { + public static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { MPImageContainer container; MPImageProperties byteBufferProperties = MPImageProperties.builder() @@ -83,12 +86,16 @@ public class ByteBufferExtractor { .build(); if ((container = image.getContainer(byteBufferProperties)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + return byteBufferImageContainer + .getByteBuffer() + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) - .asReadOnlyBuffer(); + .asReadOnlyBuffer() + .order(ByteOrder.nativeOrder()); } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ByteBuffer byteBuffer = diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index e17cc4d30..946beae37 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -67,6 +67,8 @@ public class MPImage implements Closeable { IMAGE_FORMAT_YUV_420_888, IMAGE_FORMAT_ALPHA, IMAGE_FORMAT_JPEG, + IMAGE_FORMAT_VEC32F1, + IMAGE_FORMAT_VEC32F2, }) 
@Retention(RetentionPolicy.SOURCE) public @interface MPImageFormat {} @@ -81,6 +83,8 @@ public class MPImage implements Closeable { public static final int IMAGE_FORMAT_YUV_420_888 = 7; public static final int IMAGE_FORMAT_ALPHA = 8; public static final int IMAGE_FORMAT_JPEG = 9; + public static final int IMAGE_FORMAT_VEC32F1 = 10; + public static final int IMAGE_FORMAT_VEC32F2 = 11; /** Specifies the image container type. Would be useful for choosing extractors. */ @IntDef({ diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 4926e2f3c..4540f63a6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -84,12 +84,11 @@ cc_library( deps = [ ":class_registry", ":jni_util", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_profile_cc_proto", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 6a67c01cb..23bd553af 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -231,8 +231,6 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name))); sink_node->set_calculator("GlSurfaceSinkCalculator"); sink_node->add_input_stream(output_stream_name); - sink_node->add_input_side_packet( - absl::StrCat(kGpuSharedTagName, ":", kGpuSharedSidePacketName)); const std::string input_side_packet_name = mediapipe::tool::GetUnusedSidePacketName( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 84df89260..dd99cccd4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -15,20 +15,16 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" using mediapipe::GlTextureBufferSharedPtr; JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + JNIEnv* env, jobject thiz, jlong nativeHandle) { GlTextureBufferSharedPtr* buffer = reinterpret_cast(nativeHandle); - if (consumerSyncToken) { - mediapipe::GlSyncToken& token = - *reinterpret_cast(consumerSyncToken); - (*buffer)->DidRead(token); - } delete buffer; } @@ -84,3 +80,18 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( } return reinterpret_cast(token); } + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + 
nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) { + return reinterpret_cast( + mediapipe::GlContext::GetCurrentNativeContext()); +} + +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + GlTextureBufferSharedPtr* buffer = + reinterpret_cast(nativeHandle); + mediapipe::GlSyncToken& token = + *reinterpret_cast(consumerSyncToken); + (*buffer)->DidRead(token); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 45637bb31..41c531fff 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -26,7 +26,7 @@ extern "C" { // Releases a native mediapipe::GpuBuffer. JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEnv* env, jobject thiz, jlong nativeHandle); JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEnv* env, jobject thiz, jlong nativeHandle); @@ -44,6 +44,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 250d7c938..46ea1ce41 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image.h" @@ -107,55 +109,31 @@ absl::StatusOr CreateGpuBuffer( // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // ByteBuffer. -std::unique_ptr CreateImageFrameFromByteBuffer( - JNIEnv* env, jobject byte_buffer, jint width, jint height, - mediapipe::ImageFormat::Format format) { - switch (format) { - case mediapipe::ImageFormat::SRGBA: - case mediapipe::ImageFormat::SRGB: - case mediapipe::ImageFormat::GRAY8: - break; - default: - LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; - return nullptr; - } - - auto image_frame = std::make_unique( - format, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - +absl::StatusOr> +CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, + jint height, jint width_step, + mediapipe::ImageFormat::Format format) { const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - const int num_channels = image_frame->NumberOfChannels(); - const int expected_buffer_size = - num_channels == 1 ? 
width * height : image_frame->PixelDataSize(); - - if (buffer_size != expected_buffer_size) { - if (num_channels != 1) - LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; - return nullptr; + const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); } - // Copy buffer data to image frame's pixel_data_. - if (num_channels == 1) { - const int width_step = image_frame->WidthStep(); - const char* src_row = - reinterpret_cast(env->GetDirectBufferAddress(byte_buffer)); - char* dst_row = reinterpret_cast(image_frame->MutablePixelData()); - for (int i = height; i > 0; --i) { - std::memcpy(dst_row, src_row, width); - src_row += width; - dst_row += width_step; - } - } else { - // 3 and 4 channels. - const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); - std::memcpy(image_frame->MutablePixelData(), buffer_data, - image_frame->PixelDataSize()); - } + const int expected_buffer_size = height * width_step; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; + + auto image_frame = std::make_unique(); + // TODO: we could retain the buffer with a special deleter and use + // the data directly without a copy. May need a new Java API since existing + // code might expect to be able to overwrite the buffer after creating an + // ImageFrame from it. + image_frame->CopyPixelData( + format, width, height, width_step, static_cast(buffer_data), + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); return image_frame; } @@ -176,77 +154,83 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); - if (nullptr == image_frame) return 0L; + // We require 4-byte alignment. See Java method. + constexpr int kAlignment = 4; + int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1; + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, + width_step, mediapipe::ImageFormat::SRGB); + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } +absl::StatusOr> CreateRgbImageFromRgba( + JNIEnv* env, jobject byte_buffer, jint width, jint height) { + const uint8_t* rgba_data = + static_cast(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. 
It should be created " + "using allocateDirect."); + } + + const int expected_buffer_size = width * height * 4; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; + + auto image_frame = absl::make_unique( + mediapipe::ImageFormat::SRGB, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + return image_frame; +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const uint8_t* rgba_data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height * 4) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height * 4 - << ", Image width: " << width; - return 0L; - } - mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); - if (nullptr == image_frame) return 0L; + auto image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8); + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::VEC32F1); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = 
mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); - if (nullptr == image_frame) return 0L; - - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::SRGBA); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -291,6 +275,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)( jint num_samples) { const uint8_t* audio_sample = reinterpret_cast(env->GetDirectBufferAddress(data)); + if (!audio_sample) { + ThrowIfError(env, absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It " + "should be created using allocateDirect.")); + return 0L; + } mediapipe::Packet packet = createAudioPacket(audio_sample, num_samples, num_channels); return CreatePacketWithContext(context, packet); @@ -360,8 +350,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " - << rows * cols; + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Please check the matrix data size, has to be rows * cols = ", + rows * cols))); return 0L; } std::unique_ptr matrix(new mediapipe::Matrix(rows, cols)); @@ -379,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels) { + jint height, jint width_step, jint num_channels) { mediapipe::ImageFormat::Format format; switch (num_channels) { case 4: @@ -392,16 +384,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( format = mediapipe::ImageFormat::GRAY8; break; default: - LOG(ERROR) << "Channels must be either 1, 3, or 4."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Channels must be either 1, 3, or 4, but are ", + num_channels))); return 0L; } - auto image_frame = - CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); - if (nullptr == image_frame) return 0L; + auto image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width_step, format); + if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(std::move(image_frame)); + mediapipe::MakePacket(*std::move(image_frame_or)); return CreatePacketWithContext(context, packet); } @@ -502,7 +496,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( jbyte* data_ref = env->GetByteArrayElements(data, nullptr); auto options = absl::make_unique(); if (!options->ParseFromArray(data_ref, count)) { - LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Parsing binary-encoded CalculatorOptions failed."))); return 0L; } 
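As a side note on the 4-byte row alignment introduced for nativeCreateRgbImage above ("We require 4-byte alignment. See Java method."), the expression ((width * 3 - 1) | (kAlignment - 1)) + 1 rounds the RGB row stride up to the next multiple of 4. A small sketch of that arithmetic, with arbitrary example widths (not taken from the change):

    # Mirrors the C++ stride computation; alignment must be a power of two.
    def rgb_width_step(width, alignment=4):
        return ((width * 3 - 1) | (alignment - 1)) + 1

    assert rgb_width_step(4) == 12  # 12 bytes per row is already 4-byte aligned
    assert rgb_width_step(5) == 16  # 15 bytes rounds up to 16
    assert rgb_width_step(6) == 20  # 18 bytes rounds up to 20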
mediapipe::Packet packet = mediapipe::Adopt(options.release()); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index d6f44b0a3..b3b1043fb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels); + jint height, jint width_step, jint num_channels); JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index c215dd929..234209b8c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,8 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -38,6 +40,52 @@ template const T& GetFromNativeHandle(int64_t packet_handle) { return mediapipe::android::Graph::GetPacketFromHandle(packet_handle).Get(); } + +bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image, + jobject byte_buffer) { + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } + + // Assume byte buffer stores pixel data contiguously. 
+ const int expected_buffer_size = image.Width() * image.Height() * + image.ByteDepth() * image.NumberOfChannels(); + if (buffer_size != expected_buffer_size) { + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Expected buffer size ", expected_buffer_size, + " got: ", buffer_size, ", width ", image.Width(), ", height ", + image.Height(), ", channels ", image.NumberOfChannels()))); + return false; + } + + switch (image.ByteDepth()) { + case 1: { + uint8* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + case 2: { + uint16* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + case 4: { + float* data = static_cast(buffer_data); + image.CopyToBuffer(data, expected_buffer_size); + break; + } + default: { + return false; + } + } + return true; +} + } // namespace JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetPacketFromReference)( @@ -297,42 +345,51 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( .GetImageFrameSharedPtr() .get() : GetFromNativeHandle(packet); + return CopyImageDataToByteBuffer(env, image, byte_buffer); +} - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); +JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)( + JNIEnv* env, jobject thiz, jlong packet) { + const auto& image_list = + GetFromNativeHandle>(packet); + return image_list.size(); +} - // Assume byte buffer stores pixel data contiguously. - const int expected_buffer_size = image.Width() * image.Height() * - image.ByteDepth() * image.NumberOfChannels(); - if (buffer_size != expected_buffer_size) { - LOG(ERROR) << "Expected buffer size " << expected_buffer_size - << " got: " << buffer_size << ", width " << image.Width() - << ", height " << image.Height() << ", channels " - << image.NumberOfChannels(); +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)( + JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array, + jboolean deep_copy) { + const auto& image_list = + GetFromNativeHandle>(packet); + if (env->GetArrayLength(byte_buffer_array) != image_list.size()) { + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Expected ByteBuffer array size: ", image_list.size(), + " but get ByteBuffer array size: ", + env->GetArrayLength(byte_buffer_array)))); return false; } - - switch (image.ByteDepth()) { - case 1: { - uint8* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - case 2: { - uint16* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - case 4: { - float* data = - static_cast(env->GetDirectBufferAddress(byte_buffer)); - image.CopyToBuffer(data, expected_buffer_size); - break; - } - default: { + for (int i = 0; i < image_list.size(); ++i) { + auto& image = *image_list[i].GetImageFrameSharedPtr().get(); + if (!image.IsContiguous()) { + ThrowIfError( + env, absl::InternalError("ImageFrame must store data contiguously to " + "be allocated as ByteBuffer.")); return false; } + if (deep_copy) { + jobject byte_buffer = reinterpret_cast( + env->GetObjectArrayElement(byte_buffer_array, i)); + if (!CopyImageDataToByteBuffer(env, image, byte_buffer)) { + return false; + } + } else { + // Assume byte buffer stores pixel data contiguously. 
+ const int expected_buffer_size = image.Width() * image.Height() * + image.ByteDepth() * + image.NumberOfChannels(); + jobject image_data_byte_buffer = env->NewDirectByteBuffer( + image.MutablePixelData(), expected_buffer_size); + env->SetObjectArrayElement(byte_buffer_array, i, image_data_byte_buffer); + } } return true; } @@ -351,12 +408,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( uint8_t* rgba_data = static_cast(env->GetDirectBufferAddress(byte_buffer)); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } if (buffer_size != image.Width() * image.Height() * 4) { - LOG(ERROR) << "Buffer size has to be width*height*4\n" - << "Image width: " << image.Width() - << ", Image height: " << image.Height() - << ", Buffer size: " << buffer_size << ", Buffer size needed: " - << image.Width() * image.Height() * 4; + ThrowIfError(env, + absl::InvalidArgumentError(absl::StrCat( + "Buffer size has to be width*height*4\n" + "Image width: ", + image.Width(), ", Image height: ", image.Height(), + ", Buffer size: ", buffer_size, ", Buffer size needed: ", + image.Width() * image.Height() * 4))); return false; } mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(), @@ -403,7 +467,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)( int16 value = static_cast(audio_mat(channel, sample) * kMultiplier); // The java and native has the same byte order, by default is little - // Endian, we can safely copy data directly, we have tests to cover this. + // Endian, we can safely copy data directly, we have tests to cover + // this. env->SetByteArrayRegion(byte_data, offset, 2, reinterpret_cast(&value)); offset += 2; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h index 6a20d3daf..4602ebd59 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h @@ -106,6 +106,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)(JNIEnv* env, JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer); +// Return the vector size of std::vector. +JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)( + JNIEnv* env, jobject thiz, jlong packet); + +// Fill ByteBuffer[] from the Packet of std::vector. +// Before calling this, the byte_buffer_array needs to have the correct +// allocated size. +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)( + JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array, + jboolean deep_copy); + // Before calling this, the byte_buffer needs to have the correct allocated // size. 
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( diff --git a/mediapipe/model_maker/MANIFEST.in b/mediapipe/model_maker/MANIFEST.in new file mode 100644 index 000000000..54ce01aff --- /dev/null +++ b/mediapipe/model_maker/MANIFEST.in @@ -0,0 +1 @@ +recursive-include pip_src/mediapipe_model_maker/models * diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 7ca2f9216..b37088764 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.text import text_classifier + +# Remove duplicated and non-public API +del python diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD new file mode 100644 index 000000000..4c54bbccc --- /dev/null +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], +) + +mediapipe_files( + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) + +filegroup( + name = "mobilebert_tiny", + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 200726864..abcfff835 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -37,7 +37,7 @@ class Classifier(custom_model.CustomModel): label_names: A list of label names for the classes. shuffle: Whether the dataset should be shuffled. """ - super(Classifier, self).__init__(model_spec, shuffle) + super().__init__(model_spec, shuffle) self._label_names = label_names self._num_classes = len(label_names) self._model: tf.keras.Model = None @@ -48,11 +48,11 @@ class Classifier(custom_model.CustomModel): self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None - # TODO: Integrate this into all Model Maker tasks. 
def _train_model(self, train_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset, - preprocessor: Optional[Callable[..., bool]] = None): + preprocessor: Optional[Callable[..., bool]] = None, + checkpoint_path: Optional[str] = None): """Trains the classifier model. Compiles and fits the tf.keras `_model` and records the `_history`. @@ -62,6 +62,9 @@ class Classifier(custom_model.CustomModel): validation_data: Validation data. preprocessor: An optional data preprocessor that can be used when generating a tf.data.Dataset. + checkpoint_path: An optional directory for the checkpoint file to support + continual training. If provided, loads model weights from the latest + checkpoint in the directory. """ tf.compat.v1.logging.info('Training the models...') if len(train_data) < self._hparams.batch_size: @@ -88,9 +91,21 @@ class Classifier(custom_model.CustomModel): optimizer=self._optimizer, loss=self._loss_function, metrics=[self._metric_function]) + + latest_checkpoint = ( + tf.train.latest_checkpoint(checkpoint_path) + if checkpoint_path else None) + if latest_checkpoint: + print(f'Resuming from {latest_checkpoint}') + self._model.load_weights(latest_checkpoint) + self._history = self._model.fit( x=train_dataset, epochs=self._hparams.epochs, + # `steps_per_epoch` is intentionally set to None in case the dataset + # is not repeated. Otherwise, the training process will stop when the + # dataset is exhausted even if there are epochs remaining. + steps_per_epoch=None, validation_data=validation_dataset, callbacks=self._callbacks) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 12fef631f..3c9107dba 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -45,6 +45,7 @@ py_test( name = "model_util_test", srcs = ["model_util_test.py"], deps = [ + ":file_util", ":model_util", ":quantization", ":test_util", @@ -60,6 +61,7 @@ py_test( name = "file_util_test", srcs = ["file_util_test.py"], data = ["//mediapipe/model_maker/python/core/utils/testdata"], + tags = ["requires-net:external"], deps = [":file_util"], ) diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index bccf928e2..29d11ebbe 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -13,13 +13,95 @@ # limitations under the License. """Utilities for files.""" +import dataclasses import os +import pathlib +import shutil +import tarfile +import tempfile +import requests # resources dependency +_TEMPDIR_FOLDER = 'model_maker' + + +@dataclasses.dataclass +class DownloadedFiles: + """File(s) that are downloaded from a url into a local directory. + + If `is_folder` is True: + 1. `path` should be a folder + 2. `url` should point to a .tar.gz file which contains a single folder at + the root level. + + Attributes: + path: Relative path in local directory. + url: GCS url to download the file(s). + is_folder: Whether the path and url represents a folder. + """ + + path: str + url: str + is_folder: bool = False + + def get_path(self) -> str: + """Gets the path of files saved in a local directory. + + If the path doesn't exist, this method will download the file(s) from the + provided url. The path is not cleaned up so it can be reused for subsequent + calls to the same path. 
+ Folders are expected to be zipped in a .tar.gz file which will be extracted + into self.path in the local directory. + + Raises: + RuntimeError: If the extracted folder does not have a singular root + directory. + + Returns: + The absolute path to the downloaded file(s) + """ + tmpdir = tempfile.gettempdir() + absolute_path = pathlib.Path( + os.path.join(tmpdir, _TEMPDIR_FOLDER, self.path) + ) + if not absolute_path.exists(): + print(f'Downloading {self.url} to {absolute_path}') + r = requests.get(self.url, allow_redirects=True) + if self.is_folder: + # Use tempf to store the downloaded .tar.gz file + tempf = tempfile.NamedTemporaryFile(suffix='.tar.gz', mode='wb') + tempf.write(r.content) + tarf = tarfile.open(tempf.name) + # Use tmpdir to store the extracted contents of the .tar.gz file + with tempfile.TemporaryDirectory() as tmpdir: + tarf.extractall(tmpdir) + tarf.close() + tempf.close() + subdirs = os.listdir(tmpdir) + # Make sure tmpdir only has one subdirectory + if len(subdirs) > 1 or not os.path.isdir( + os.path.join(tmpdir, subdirs[0]) + ): + raise RuntimeError( + f"Extracted folder from {self.url} doesn't contain a " + f'single root directory: {subdirs}' + ) + # Create the parent dir of absolute_path and copy the contents of the + # top level dir in the .tar.gz file into absolute_path. + pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True) + shutil.copytree(os.path.join(tmpdir, subdirs[0]), absolute_path) + else: + pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True) + with open(absolute_path, 'wb') as f: + f.write(r.content) + return str(absolute_path) + + +# TODO Remove after text_classifier supports downloading on demand. def get_absolute_path(file_path: str) -> str: - """Gets the absolute path of a file. + """Gets the absolute path of a file in the model_maker directory. Args: file_path: The path to a file relative to the `mediapipe` dir @@ -27,10 +109,17 @@ def get_absolute_path(file_path: str) -> str: Returns: The full path of the file """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `path` which defines the relative path under mediapipe/, it - # yields to the absolute path of the model files directory. + # Extract the file path before and including 'model_maker' as the + # `mm_base_dir`. By joining it with the `path` after 'model_maker/', it + # yields to the absolute path of the model files directory. We must join + # on 'model_maker' because in the pypi package, the 'model_maker' directory + # is renamed to 'mediapipe_model_maker'. So we have to join on model_maker + # to ensure that the `mm_base_dir` path includes the renamed + # 'mediapipe_model_maker' directory. 
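A concrete illustration of that join (the install prefix is hypothetical; the model path is the one the text classifier spec passes in later in this change):

    # cwd       = '.../site-packages/mediapipe_model_maker/python/core/utils'
    # file_path = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
    # mm_base_dir      -> '.../site-packages/mediapipe_model_maker'
    # mm_relative_path -> 'models/text_classifier/mobilebert_tiny/'
    # absolute_path    -> '.../site-packages/mediapipe_model_maker/models/text_classifier/mobilebert_tiny/'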
cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, file_path) + cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') + mm_base_dir = cwd[:cwd_stop_idx] + file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 + mm_relative_path = file_path[file_path_start_idx:] + absolute_path = os.path.join(mm_base_dir, mm_relative_path) return absolute_path diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py index 4a2d6dcfb..f9f4a5954 100644 --- a/mediapipe/model_maker/python/core/utils/file_util_test.py +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -12,13 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import tempfile +from unittest import mock as unittest_mock from absl.testing import absltest +import requests + from mediapipe.model_maker.python.core.utils import file_util class FileUtilTest(absltest.TestCase): + def setUp(self): + super().setUp() + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + + def test_get_path(self): + path = 'gesture_recognizer/hand_landmark_full.tflite' + url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite' + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + self.assertGreater(os.path.getsize(model_path), 0) + + def test_get_path_folder(self): + folder_contents = [ + 'keras_metadata.pb', + 'saved_model.pb', + 'assets/vocab.txt', + 'variables/variables.data-00000-of-00001', + 'variables/variables.index', + ] + path = 'text_classifier/mobilebert_tiny' + url = ( + 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz' + ) + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=True) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + for file_name in folder_contents: + file_path = os.path.join(model_path, file_name) + self.assertTrue(os.path.exists(file_path)) + self.assertGreater(os.path.getsize(file_path), 0) + + @unittest_mock.patch.object(requests, 'get', wraps=requests.get) + def test_get_path_multiple_calls(self, mock_get): + path = 'gesture_recognizer/hand_landmark_full.tflite' + url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite' + downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False) + model_path = downloaded_files.get_path() + self.assertTrue(os.path.exists(model_path)) + self.assertGreater(os.path.getsize(model_path), 0) + model_path_2 = downloaded_files.get_path() + self.assertEqual(model_path, model_path_2) + self.assertEqual(mock_get.call_count, 1) + def test_get_absolute_path(self): test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt' absolute_path = file_util.get_absolute_path(test_file) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index f10d9390c..db02444df 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -42,7 +42,9 @@ def get_default_callbacks( checkpoint_path = os.path.join(export_dir, 
'checkpoint') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - checkpoint_path, save_weights_only=True) + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True, + period=5) return [summary_callback, checkpoint_callback] diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index bef9c8a97..f0020db25 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -13,10 +13,13 @@ # limitations under the License. import os +from typing import Optional +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -24,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_keras_model(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_keras_model(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) + # model_util.load_keras_model takes in a relative path to files within the + # model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -36,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - def test_load_tflite_model_buffer(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_tflite_model_buffer(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - + # model_util.load_tflite_model_buffer takes in a relative path to files + # within the model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, @@ -76,8 +86,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]), expected_steps_per_epoch=2)) - def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, - expected_steps_per_epoch): + def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int], + batch_size: Optional[int], + train_data: Optional[tf.data.Dataset], + expected_steps_per_epoch: int): estimated_steps_per_epoch = model_util.get_steps_per_epoch( steps_per_epoch=steps_per_epoch, batch_size=batch_size, @@ -130,7 +142,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', 
config=quantization.QuantizationConfig.for_float16(), + model_size=1468)) - def test_convert_to_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, + config: quantization.QuantizationConfig, + model_size: int): input_dim = 16 num_classes = 2 max_input_value = 5 @@ -157,5 +171,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): test_util.test_tflite_file( keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + if __name__ == '__main__': tf.test.main() diff --git a/mediapipe/model_maker/python/text/__init__.py b/mediapipe/model_maker/python/text/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 0c35e7966..ac5b04f20 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -21,9 +21,16 @@ package( licenses(["notice"]) +###################################################################### +# Public target of the MediaPipe Model Maker TextClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/text/text_classifier/customize for +# more information about the MediaPipe Model Maker TextClassifier APIs.
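As a rough orientation for that public target, a minimal usage sketch of the exported API; the entry points follow the names exported from __init__.py and exercised by the unit test later in this change, while the CSV file names, column names, and exact keyword arguments are assumptions for illustration:

    from mediapipe.model_maker import text_classifier

    # Hypothetical CSV files with 'text' and 'label' columns.
    csv_params = text_classifier.CSVParams(text_column='text', label_column='label')
    train_data = text_classifier.Dataset.from_csv(filename='train.csv', csv_params=csv_params)
    validation_data = text_classifier.Dataset.from_csv(filename='dev.csv', csv_params=csv_params)

    options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)
    model = text_classifier.TextClassifier.create(train_data, validation_data, options)

    loss, accuracy = model.evaluate(validation_data)
    model.export_model()  # writes model.tflite and metadata.json under hparams.export_dir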
+###################################################################### py_library( name = "text_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":model_options", @@ -46,6 +53,7 @@ py_library( deps = [ ":model_options", "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/text/core:bert_model_spec", ], ) @@ -81,6 +89,9 @@ py_library( py_test( name = "preprocessor_test", srcs = ["preprocessor_test.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], tags = ["requires-net:external"], deps = [ ":dataset", @@ -102,6 +113,9 @@ py_library( py_library( name = "text_classifier", srcs = ["text_classifier.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":dataset", ":model_options", @@ -123,9 +137,14 @@ py_test( size = "large", srcs = ["text_classifier_test.py"], data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], - tags = ["requires-net:external"], + tags = [ + "notsan", + "requires-mem:16g", + "requires-net:external", + ], deps = [ ":text_classifier_import", "//mediapipe/tasks/python/test:test_utils", @@ -144,6 +163,9 @@ py_library( py_binary( name = "text_classifier_demo", srcs = ["text_classifier_demo.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":text_classifier_demo_lib", ], diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 618e51645..697461969 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions SupportedModels = model_spec.SupportedModels TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier_options.TextClassifierOptions + +# Remove duplicated and non-public API +del hyperparameters +del dataset +del model_options +del model_spec +del preprocessor # pylint: disable=undefined-variable +del text_classifier +del text_classifier_options diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 9df7e1039..a6bdd9522 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -18,12 +18,15 @@ import enum import functools from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.text_classifier import model_options as mo # BERT-based text classifier spec inherited from BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec +MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/' + @dataclasses.dataclass class AverageWordEmbeddingClassifierSpec: @@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial( BertClassifierSpec, hparams=hp.BaseHParams( - epochs=3, - batch_size=48, - learning_rate=3e-5, - distribution_strategy='off'), + epochs=3, batch_size=48, learning_rate=3e-5, 
distribution_strategy='off' + ), name='MobileBert', - uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', + uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH), tflite_input_name={ 'ids': 'serving_default_input_1:0', 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' + 'segment_ids': 'serving_default_input_2:0', }, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index dd7f880f3..3ea019b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertEqual( - model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' - 'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') + self.assertIn( + 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny', + model_spec_obj.uri, + ) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, { diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD index 663c72082..a581462cf 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -19,5 +19,8 @@ package( filegroup( name = "testdata", - srcs = ["average_word_embedding_metadata.json"], + srcs = [ + "average_word_embedding_metadata.json", + "bert_metadata.json", + ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json new file mode 100644 index 000000000..24214a80d --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -0,0 +1,84 @@ +{ + "name": "TextClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + 
"description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +} diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c285702d2..f6abc8bf0 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer -from official.nlp import optimization def _validate(options: text_classifier_options.TextClassifierOptions): @@ -270,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier): """Creates an Average Word Embedding model.""" self._model = tf.keras.Sequential([ tf.keras.layers.InputLayer( - input_shape=[self._model_options.seq_len], dtype=tf.int32), + input_shape=[self._model_options.seq_len], + dtype=tf.int32, + name="input_ids", + ), tf.keras.layers.Embedding( len(self._text_preprocessor.get_vocab()), self._model_options.wordvec_dim, - input_length=self._model_options.seq_len), + input_length=self._model_options.seq_len, + ), tf.keras.layers.GlobalAveragePooling1D(), tf.keras.layers.Dense( - self._model_options.wordvec_dim, activation=tf.nn.relu), + self._model_options.wordvec_dim, activation=tf.nn.relu + ), tf.keras.layers.Dropout(self._model_options.dropout_rate), - tf.keras.layers.Dense(self._num_classes, activation="softmax") + tf.keras.layers.Dense(self._num_classes, activation="softmax"), ]) def _save_vocab(self, vocab_filepath: str): @@ -417,8 +421,22 @@ class _BertClassifier(TextClassifier): total_steps = self._hparams.steps_per_epoch * self._hparams.epochs warmup_steps = int(total_steps * 0.1) initial_lr = self._hparams.learning_rate - self._optimizer = optimization.create_optimizer(initial_lr, total_steps, - warmup_steps) + # Implements linear decay of the learning rate. 
+ lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=initial_lr, + decay_steps=total_steps, + end_learning_rate=0.0, + power=1.0) + if warmup_steps: + lr_schedule = model_util.WarmUp( + initial_learning_rate=initial_lr, + decay_schedule_fn=lr_schedule, + warmup_steps=warmup_steps) + + self._optimizer = tf.keras.optimizers.experimental.AdamW( + lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0) + self._optimizer.exclude_from_weight_decay( + var_names=["LayerNorm", "layer_norm", "bias"]) def _save_vocab(self, vocab_filepath: str): tf.io.gfile.copy( diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..1ae2bc553 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) + _BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path( + 'bert_metadata.json' + ) def _get_data(self): labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) @@ -71,9 +74,12 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() @@ -91,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase): _, accuracy = bert_classifier.evaluate(validation_data) self.assertGreaterEqual(accuracy, 0.0) - # TODO: Add a unit test that does not run OOM. + + # Test export_model + bert_classifier.export_model() + output_metadata_file = os.path.join( + options.hparams.export_dir, 'metadata.json' + ) + output_tflite_file = os.path.join( + options.hparams.export_dir, 'model.tflite' + ) + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False + ) + ) def test_label_mismatch(self): options = ( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b7d334d9c..cbdff7cf3 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -24,9 +24,9 @@ package( # TODO: Remove the unncessary test data once the demo data are moved to an open-sourced # directory. filegroup( - name = "test_data", + name = "testdata", srcs = glob([ - "test_data/**", + "testdata/**", ]), ) @@ -35,31 +35,33 @@ py_library( srcs = ["constants.py"], ) -# TODO: Change to py_library after migrating the MediaPipe hand solution -# library to MediaPipe hand task library. 
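For orientation, the _BertClassifier optimizer change above replaces the TF-official optimization.create_optimizer helper with an explicit schedule: linear decay to zero over the total step count, wrapped in model_util.WarmUp for the first 10% of steps, and fed to AdamW with weight-decay exclusions. A standalone sketch of just the decay portion; the step counts and learning rate below are illustrative, not taken from the change:

    import tensorflow as tf

    total_steps = 1000   # stands in for hparams.steps_per_epoch * hparams.epochs
    initial_lr = 3e-5

    # Linear (power=1.0) decay from initial_lr down to 0 over total_steps.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        end_learning_rate=0.0,
        power=1.0)

    print(float(lr_schedule(0)))     # 3e-05 at the start
    print(float(lr_schedule(500)))   # 1.5e-05 halfway through the decay
    print(float(lr_schedule(1000)))  # 0.0 at the end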
py_library( name = "dataset", srcs = ["dataset.py"], deps = [ ":constants", + ":metadata_writer", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/core/data:data_util", "//mediapipe/model_maker/python/core/utils:model_util", - "//mediapipe/python/solutions:hands", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "dataset_test", srcs = ["dataset_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], + tags = ["notsan"], deps = [ ":dataset", - "//mediapipe/python/solutions:hands", "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) @@ -103,9 +105,16 @@ py_library( ], ) +###################################################################### +# Public target of the MediaPipe Model Maker GestureRecognizer APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer/customize +# for more information about the MediaPipe Model Maker GestureRecognizer APIs. +###################################################################### py_library( name = "gesture_recognizer_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":gesture_recognizer", @@ -124,17 +133,21 @@ py_library( ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "gesture_recognizer_test", size = "large", srcs = ["gesture_recognizer_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, + tags = ["notsan"], deps = [ ":gesture_recognizer_import", + ":hyperparameters", + ":model_options", "//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/tasks/python/test:test_utils", ], @@ -144,7 +157,7 @@ py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], data = [ - ":test_data", + ":testdata", ], deps = [ ":metadata_writer", @@ -157,7 +170,7 @@ py_binary( name = "gesture_recognizer_demo", srcs = ["gesture_recognizer_demo.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], python_version = "PY3", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py index dc6923fac..a302e8d79 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -25,3 +25,12 @@ HParams = hyperparameters.HParams Dataset = dataset.Dataset HandDataPreprocessingParams = dataset.HandDataPreprocessingParams GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions + +# Remove duplicated and non-public API +del constants # pylint: disable=undefined-variable +del dataset +del gesture_recognizer +del gesture_recognizer_options +del hyperparameters +del metadata_writer # pylint: disable=undefined-variable +del model_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 256f26fd6..6a2c878c0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ 
-16,16 +16,22 @@ import dataclasses import os import random -from typing import List, NamedTuple, Optional +from typing import List, Optional -import cv2 import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.core.data import data_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.vision.gesture_recognizer import constants -from mediapipe.python.solutions import hands as mp_hands +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module + +_Image = image_module.Image +_HandLandmarker = hand_landmarker_module.HandLandmarker +_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions +_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult @dataclasses.dataclass @@ -59,7 +65,7 @@ class HandData: handedness: List[float] -def _validate_data_sample(data: NamedTuple) -> bool: +def _validate_data_sample(data: _HandLandmarkerResult) -> bool: """Validates the input hand data sample. Args: @@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool: 'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness' or any of these attributes' values are none. Otherwise, True. """ - if (not hasattr(data, 'multi_hand_landmarks') or - data.multi_hand_landmarks is None): + if data.hand_landmarks is None or not data.hand_landmarks: return False - if (not hasattr(data, 'multi_hand_world_landmarks') or - data.multi_hand_world_landmarks is None): + if data.hand_world_landmarks is None or not data.hand_world_landmarks: return False - if not hasattr(data, 'multi_handedness') or data.multi_handedness is None: + if data.handedness is None or not data.handedness: return False return True def _get_hand_data(all_image_paths: List[str], - min_detection_confidence: float) -> Optional[HandData]: + min_detection_confidence: float) -> List[Optional[HandData]]: """Computes hand data (landmarks and handedness) in the input image. Args: @@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str], A HandData object. Returns None if no hand is detected. 
""" hand_data_result = [] - with mp_hands.Hands( - static_image_mode=True, - max_num_hands=1, - min_detection_confidence=min_detection_confidence) as hands: + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + hand_landmarker_options = _HandLandmarkerOptions( + base_options=base_options_module.BaseOptions( + model_asset_buffer=hand_landmarker_writer.populate()), + num_hands=1, + min_hand_detection_confidence=min_detection_confidence, + min_hand_presence_confidence=0.5, + min_tracking_confidence=1, + ) + with _HandLandmarker.create_from_options( + hand_landmarker_options) as hand_landmarker: for path in all_image_paths: tf.compat.v1.logging.info('Loading image %s', path) - image = data_util.load_image(path) - # Flip image around y-axis for correct handedness output - image = cv2.flip(image, 1) - data = hands.process(image) + image = _Image.create_from_file(path) + data = hand_landmarker.detect(image) if not _validate_data_sample(data): hand_data_result.append(None) continue - hand_landmarks = [[ - hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_landmarks[0].landmark] + hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z] + for hand_landmark in data.hand_landmarks[0]] hand_world_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark] + ] for hand_landmark in data.hand_world_landmarks[0]] handedness_scores = [ - handedness.score - for handedness in data.multi_handedness[0].classification + handedness.score for handedness in data.handedness[0] ] hand_data_result.append( HandData( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 76e70a58d..528d02edd 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import collections import os import shutil from typing import NamedTuple import unittest -from absl import flags from absl.testing import parameterized import tensorflow as tf from mediapipe.model_maker.python.vision.gesture_recognizer import dataset -from mediapipe.python.solutions import hands as mp_hands from mediapipe.tasks.python.test import test_utils - -FLAGS = flags.FLAGS +from mediapipe.tasks.python.vision import hand_landmarker _TEST_DATA_DIRNAME = 'raw_data' @@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) train_data, test_data = data.split(0.5) - self.assertLen(train_data, 17) + self.assertLen(train_data, 16) for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) self.assertEqual(train_data.num_classes, 4) self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) - self.assertLen(test_data, 18) + self.assertLen(test_data, 16) for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) @@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) @@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK']) @parameterized.named_parameters( dict( - testcase_name='invalid_field_name_multi_hand_landmark', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmark', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=None, hand_landmarks=[[2]], + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_hand_world_landmarks', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmark', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=None, + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_handed', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handed' - ])(1, 2, 3)), + testcase_name='none_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], + hand_world_landmarks=None)), dict( - testcase_name='multi_hand_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(None, 2, 3)), + testcase_name='empty_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_hand_world_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 
'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, None, 3)), + testcase_name='empty_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_handedness_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, None)), + testcase_name='empty_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), ) def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): with unittest.mock.patch.object( - mp_hands.Hands, 'process', return_value=hand): + hand_landmarker.HandLandmarker, 'detect', return_value=hand): input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): dataset.Dataset.from_folder( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index f297d8640..b27f7161f 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -53,6 +53,10 @@ class GestureRecognizer(classifier.Classifier): model_spec=None, label_names=label_names, shuffle=hparams.shuffle) self._model_options = model_options self._hparams = hparams + self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) + self._metric_function = 'categorical_accuracy' + self._optimizer = 'adam' + self._callbacks = self._get_callbacks() self._history = None self.embedding_size = _EMBEDDING_SIZE @@ -71,7 +75,7 @@ class GestureRecognizer(classifier.Classifier): Args: train_data: Training data. - validation_data: Validation data. If None, skips validation process. + validation_data: Validation data. options: options for creating and training gesture recognizer model. Returns: @@ -87,49 +91,39 @@ class GestureRecognizer(classifier.Classifier): label_names=train_data.label_names, model_options=options.model_options, hparams=options.hparams) - - gesture_recognizer._create_model() - - train_dataset = train_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, - is_training=True, - shuffle=options.hparams.shuffle) - options.hparams.steps_per_epoch = model_util.get_steps_per_epoch( - steps_per_epoch=options.hparams.steps_per_epoch, - batch_size=options.hparams.batch_size, - train_data=train_data) - train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch) - - validation_dataset = validation_data.gen_tf_dataset( - batch_size=options.hparams.batch_size, is_training=False) - - tf.compat.v1.logging.info('Training the gesture recognizer model...') - gesture_recognizer._train( - train_data=train_dataset, validation_data=validation_dataset) - + gesture_recognizer._create_and_train_model(train_data, validation_data) return gesture_recognizer - def _train(self, train_data: tf.data.Dataset, - validation_data: tf.data.Dataset): - """Trains the model with input train_data. - - The training results are recorded by a self.History object returned by - tf.keras.Model.fit(). + def _create_and_train_model( + self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + ): + """Creates and trains the model. Args: train_data: Training data. 
validation_data: Validation data. """ + self._create_model() + self._train_model( + train_data=train_data, + validation_data=validation_data, + checkpoint_path=self._get_checkpoint_path(), + ) + + def _get_callbacks(self) -> List[tf.keras.callbacks.Callback]: + """Gets the list of callbacks to use in model training.""" hparams = self._hparams scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch) scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler) job_dir = hparams.export_dir - checkpoint_path = os.path.join(job_dir, 'epoch_models') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True) + os.path.join(self._get_checkpoint_path(), 'model-{epoch:04d}'), + save_weights_only=True, + ) best_model_path = os.path.join(job_dir, 'best_model_weights') best_model_callback = tf.keras.callbacks.ModelCheckpoint( @@ -141,27 +135,15 @@ class GestureRecognizer(classifier.Classifier): tensorboard_callback = tf.keras.callbacks.TensorBoard( log_dir=os.path.join(job_dir, 'logs')) + return [ + checkpoint_callback, + best_model_callback, + scheduler_callback, + tensorboard_callback, + ] - self._model.compile( - optimizer='adam', - loss=loss_functions.FocalLoss(gamma=self._hparams.gamma), - metrics=['categorical_accuracy']) - - latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - if latest_checkpoint: - print(f'Resuming from {latest_checkpoint}') - self._model.load_weights(latest_checkpoint) - - self._history = self._model.fit( - x=train_data, - epochs=hparams.epochs, - validation_data=validation_data, - validation_freq=1, - callbacks=[ - checkpoint_callback, best_model_callback, scheduler_callback, - tensorboard_callback - ], - ) + def _get_checkpoint_path(self) -> str: + return os.path.join(self._hparams.export_dir, 'epoch_models') def _create_model(self): """Creates the hand gesture recognizer model. @@ -172,16 +154,22 @@ class GestureRecognizer(classifier.Classifier): shape=[self.embedding_size], batch_size=None, dtype=tf.float32, - name='hand_embedding') - - x = tf.keras.layers.BatchNormalization()(inputs) - x = tf.keras.layers.ReLU()(x) + name='hand_embedding', + ) + x = inputs dropout_rate = self._model_options.dropout_rate - x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x) + for i, width in enumerate(self._model_options.layer_widths): + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) + x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) outputs = tf.keras.layers.Dense( self._num_classes, activation='softmax', - name='custom_gesture_recognizer')( + name='custom_gesture_recognizer_out')( x) self._model = tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 06075fbc6..1cf9f0619 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -31,7 +31,7 @@ FLAGS = flags.FLAGS # TODO: Move hand gesture recognizer demo dataset to an # open-sourced directory. 
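Editor's note on the `_create_model` refactor above: `layer_widths` now expands into a stack of BatchNorm, ReLU, Dropout, and Dense blocks before the softmax head. A throwaway sketch of the graph this produces for `layer_widths=[64, 32]` with four gesture classes; the embedding size, dropout rate, and class count are placeholders standing in for the real model options and dataset.

import tensorflow as tf

embedding_size, num_classes, dropout_rate = 128, 4, 0.05
layer_widths = [64, 32]

inputs = tf.keras.Input(
    shape=[embedding_size], dtype=tf.float32, name='hand_embedding')
x = inputs
# One BatchNorm -> ReLU -> Dropout -> Dense block per entry in layer_widths.
for i, width in enumerate(layer_widths):
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
    x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
# Final normalization block feeding the softmax head, as in the patch.
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
outputs = tf.keras.layers.Dense(
    num_classes, activation='softmax', name='custom_gesture_recognizer_out')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

With the default empty `layer_widths`, only the final block and the softmax layer remain, which matches the pre-refactor behavior.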
-TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data' def define_flags(): diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index eb2b1d171..4fdb74225 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import tempfile from unittest import mock as unittest_mock import zipfile @@ -22,9 +23,12 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' +tf.keras.backend.experimental.enable_tf_random_generator() class GestureRecognizerTest(tf.test.TestCase): @@ -40,30 +44,62 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._model_options = gesture_recognizer.ModelOptions() - self._hparams = gesture_recognizer.HParams(epochs=2) - self._gesture_recognizer_options = ( - gesture_recognizer.GestureRecognizerOptions( - model_options=self._model_options, hparams=self._hparams)) + tf.keras.utils.set_random_seed(87654321) all_data = self._load_data() - # Splits data, 90% data for training, 10% for testing - self._train_data, self._test_data = all_data.split(0.9) + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): + mo = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) self._test_accuracy(model) - def test_export_gesture_recognizer_model(self): + @unittest_mock.patch.object( + tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) + def test_gesture_recognizer_model_layer_widths(self, mock_dense): + layer_widths = [64, 32] + mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) + expected_calls = [ + unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}') + for i, w in enumerate(layer_widths) + ] + expected_calls.append( + unittest_mock.call( + len(self._train_data.label_names), + activation='softmax', 
+ name='custom_gesture_recognizer_out')) + self.assertLen(mock_dense.call_args_list, len(expected_calls)) + mock_dense.assert_has_calls(expected_calls) + self._test_accuracy(model) + + def test_export_gesture_recognizer_model(self): + mo = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=mo, hparams=hparams) + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._validation_data, + options=gesture_recognizer_options) model.export_model() - model_bundle_file = os.path.join(self._hparams.export_dir, + model_bundle_file = os.path.join(hparams.export_dir, 'gesture_recognizer.task') with zipfile.ZipFile(model_bundle_file) as zf: self.assertEqual( @@ -87,42 +123,48 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.5): - _, accuracy = model.evaluate(self._test_data) - tf.compat.v1.logging.info(f'accuracy: {accuracy}') - self.assertGreaterEqual(accuracy, threshold) + def _test_accuracy(self, model, threshold=0.0): + # Test on _train_data because of our limited dataset size + _, accuracy = model.evaluate(self._train_data) + tf.compat.v1.logging.info(f'train accuracy: {accuracy}') + self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( - gesture_recognizer.hyperparameters, + hyperparameters, 'HParams', autospec=True, return_value=gesture_recognizer.HParams(epochs=1)) @unittest_mock.patch.object( - gesture_recognizer.model_options, + model_options, 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) - def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=options) mock_hparams.assert_called_once() mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): + mo = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=mo, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) self._test_accuracy(model) self.assertRegex(mock_stdout.getvalue(), 'Resuming from') diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index 58b67e072..b2e851afe 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ 
b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: return f.read() +class HandLandmarkerMetadataWriter: + """MetadataWriter to write the model asset bundle for HandLandmarker.""" + + def __init__( + self, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + ) -> None: + """Initializes HandLandmarkerMetadataWriter to write model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + def populate(self): + """Creates the model asset bundle for hand landmarker task. + + Returns: + Model asset bundle in bytes + """ + landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) + return hand_landmarker_model_buffer + + class MetadataWriter: """MetadataWriter to write the metadata and the model asset bundle.""" @@ -86,8 +130,8 @@ class MetadataWriter: custom_gesture_classifier_metadata_writer: Metadata writer to write custom gesture classifier metadata into the TFLite file. """ - self._hand_detector_model_buffer = hand_detector_model_buffer - self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) self._gesture_embedder_model_buffer = gesture_embedder_model_buffer self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer @@ -147,16 +191,8 @@ class MetadataWriter: A tuple of (model_asset_bundle_in_bytes, metadata_json_content) """ # Creates the model asset bundle for hand landmarker task. - landmark_models = { - _HAND_DETECTOR_TFLITE_NAME: - self._hand_detector_model_buffer, - _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: - self._hand_landmarks_detector_model_buffer - } - output_hand_landmarker_path = os.path.join(self._temp_folder.name, - _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate( + ) # Write metadata into custom gesture classifier model. self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( @@ -179,7 +215,7 @@ class MetadataWriter: # graph. 
gesture_recognizer_models = { _HAND_LANDMARKER_BUNDLE_NAME: - read_file(output_hand_landmarker_path), + hand_landmarker_model_buffer, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: read_file(output_hand_gesture_recognizer_path), } diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index e1101e066..fd26b274d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -23,7 +23,7 @@ from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writ from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata" _EXPECTED_JSON = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) @@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( class MetadataWriterTest(tf.test.TestCase): + def test_hand_landmarker_metadata_writer(self): + # Use dummy model buffer for unit test only. + hand_detector_model_buffer = b"\x11\x12" + hand_landmarks_detector_model_buffer = b"\x22" + writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + model_bundle_content = writer.populate() + model_bundle_filepath = os.path.join(self.get_temp_dir(), + "hand_landmarker.task") + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarks_detector.tflite", "hand_detector.tflite"])) + def test_write_metadata_and_create_model_asset_bundle_successful(self): # Use dummy model buffer for unit test only. hand_detector_model_buffer = b"\x11\x12" diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py index 79a84c792..1870437d4 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py @@ -14,6 +14,7 @@ """Configurable model options for gesture recognizer models.""" import dataclasses +from typing import List @dataclasses.dataclass @@ -23,5 +24,10 @@ class GestureRecognizerModelOptions: Attributes: dropout_rate: The fraction of the input units to drop, used in dropout layer. + layer_widths: A list of hidden layer widths for the gesture model. Each + element in the list will create a new hidden layer with the specified + width. The hidden layers are separated with BatchNorm, Dropout, and ReLU. + Defaults to an empty list(no hidden layers). 
""" dropout_rate: float = 0.05 + layer_widths: List[int] = dataclasses.field(default_factory=list) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index c581d9fbc..bd916a92b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -21,9 +21,16 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) +###################################################################### +# Public target of the MediaPipe Model Maker ImageClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize for +# more information about the MediaPipe Model Maker ImageClassifier APIs. 
+###################################################################### py_library( name = "image_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":hyperparameters", @@ -80,15 +87,6 @@ py_library( ], ) -py_library( - name = "train_image_classifier_lib", - srcs = ["train_image_classifier_lib.py"], - deps = [ - ":hyperparameters", - "//mediapipe/model_maker/python/core/utils:model_util", - ], -) - py_library( name = "image_classifier", srcs = ["image_classifier.py"], @@ -97,7 +95,6 @@ py_library( ":image_classifier_options", ":model_options", ":model_spec", - ":train_image_classifier_lib", "//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/utils:model_util", @@ -114,7 +111,9 @@ py_library( srcs = ["image_classifier_test.py"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], deps = [ + ":hyperparameters", ":image_classifier_import", + ":model_options", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 3d0543cd2..4cde9e7e3 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -27,3 +27,11 @@ ModelOptions = model_options.ImageClassifierModelOptions ModelSpec = model_spec.ModelSpec SupportedModels = model_spec.SupportedModels ImageClassifierOptions = image_classifier_options.ImageClassifierOptions + +# Remove duplicated and non-public API +del dataset +del hyperparameters +del image_classifier +del image_classifier_options +del model_options +del model_spec diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 1ff6132b4..c2181121c 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -28,7 +28,6 @@ from mediapipe.model_maker.python.vision.image_classifier import hyperparameters from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms -from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer @@ -57,6 +56,10 @@ class ImageClassifier(classifier.Classifier): mean_rgb=self._model_spec.mean_rgb, stddev_rgb=self._model_spec.stddev_rgb, use_augmentation=hparams.do_data_augmentation) + self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) + self._loss_function = tf.keras.losses.CategoricalCrossentropy( + label_smoothing=self._hparams.label_smoothing) + self._metric_function = 'accuracy' self._history = None # Training history returned from `keras_model.fit`. 
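Editor's note: with the ImageClassifier target now public and training folded into `_create_and_train_model`, the end-to-end flow mirrors the updated demo further below. A hedged usage sketch, with the dataset directory, export directory, and model enum chosen for illustration only.

from mediapipe.model_maker.python.vision import image_classifier

# Placeholder data directory; any folder-per-label image layout works.
data = image_classifier.Dataset.from_folder('path/to/flower_photos')
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

options = image_classifier.ImageClassifierOptions(
    supported_model=image_classifier.SupportedModels.MOBILENET_V2,
    hparams=image_classifier.HParams(export_dir='/tmp/image_classifier_export'),
)
model = image_classifier.ImageClassifier.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options)

_, accuracy = model.evaluate(test_data)
model.export_model()  # writes model.tflite plus metadata under export_dir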
@classmethod @@ -66,7 +69,7 @@ class ImageClassifier(classifier.Classifier): validation_data: classification_ds.ClassificationDataset, options: image_classifier_options.ImageClassifierOptions, ) -> 'ImageClassifier': - """Creates and trains an image classifier. + """Creates and trains an ImageClassifier. Loads data and trains the model based on data for image classification. If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/ @@ -93,58 +96,29 @@ class ImageClassifier(classifier.Classifier): label_names=train_data.label_names, hparams=options.hparams, model_options=options.model_options) - - image_classifier._create_model() - - tf.compat.v1.logging.info('Training the models...') - image_classifier._train( - train_data=train_data, validation_data=validation_data) - + image_classifier._create_and_train_model(train_data, validation_data) return image_classifier - # TODO: Migrate to the shared training library of Model Maker. - def _train(self, train_data: classification_ds.ClassificationDataset, - validation_data: classification_ds.ClassificationDataset): - """Trains the model with input train_data. - - The training results are recorded by a self._history object returned by - tf.keras.Model.fit(). + def _create_and_train_model( + self, train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset): + """Creates and trains the model and optimizer. Args: train_data: Training data. validation_data: Validation data. """ - - tf.compat.v1.logging.info('Training the models...') - hparams = self._hparams - if len(train_data) < hparams.batch_size: - raise ValueError('The size of the train_data (%d) couldn\'t be smaller ' - 'than batch_size (%d). To solve this problem, set ' - 'the batch_size smaller or increase the size of the ' - 'train_data.' % (len(train_data), hparams.batch_size)) - - train_dataset = train_data.gen_tf_dataset( - batch_size=hparams.batch_size, - is_training=True, - shuffle=self._shuffle, - preprocess=self._preprocess) - hparams.steps_per_epoch = model_util.get_steps_per_epoch( - steps_per_epoch=hparams.steps_per_epoch, - batch_size=hparams.batch_size, + self._create_model() + self._hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, train_data=train_data) - train_dataset = train_dataset.take(count=hparams.steps_per_epoch) - - validation_dataset = validation_data.gen_tf_dataset( - batch_size=hparams.batch_size, - is_training=False, - preprocess=self._preprocess) - - # Train the model. - self._history = train_image_classifier_lib.train_model( - model=self._model, - hparams=hparams, - train_ds=train_dataset, - validation_ds=validation_dataset) + self._optimizer = self._create_optimizer() + self._train_model( + train_data=train_data, + validation_data=validation_data, + preprocessor=self._preprocess, + checkpoint_path=os.path.join(self._hparams.export_dir, 'checkpoint')) def _create_model(self): """Creates the classifier model from TFHub pretrained models.""" @@ -177,7 +151,7 @@ class ImageClassifier(classifier.Classifier): Args: model_name: File name to save TFLite model with metadata. The full export - path is {self._hparams.model_dir}/{model_name}. + path is {self._hparams.export_dir}/{model_name}. quantization_config: The configuration for model quantization. 
""" if not tf.io.gfile.exists(self._hparams.export_dir): @@ -198,3 +172,33 @@ class ImageClassifier(classifier.Classifier): model_util.save_tflite(tflite_model_with_metadata, tflite_file) with open(metadata_file, 'w') as f: f.write(metadata_json) + + def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: + """Creates an optimizer with learning rate schedule. + + Uses Keras CosineDecay schedule for the learning rate by default. + + Returns: + A tf.keras.optimizers.Optimizer for model training. + """ + # Learning rate is linear to batch size. + init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 + + # Get decay steps. + total_training_steps = self._hparams.steps_per_epoch * self._hparams.epochs + default_decay_steps = ( + self._hparams.decay_samples // self._hparams.batch_size) + decay_steps = max(total_training_steps, default_decay_steps) + + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) + warmup_steps = self._hparams.warmup_epochs * self._hparams.steps_per_epoch + if warmup_steps: + learning_rate_fn = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=warmup_steps) + optimizer = tf.keras.optimizers.RMSprop( + learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) + + return optimizer diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index 5832ea53a..f382e28aa 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str, data = image_classifier.Dataset.from_folder(data_dir) train_data, rest_data = data.split(0.8) validation_data, test_data = rest_data.split(0.5) - + model_options = image_classifier.ImageClassifierOptions( + supported_model=model_spec, + hparams=image_classifier.HParams(export_dir=export_dir), + ) model = image_classifier.ImageClassifier.create( - model_spec=model_spec, train_data=train_data, validation_data=validation_data, - hparams=image_classifier.HParams(model_dir=export_dir)) + options=model_options) _, acc = model.evaluate(test_data) print('Test accuracy: %f' % acc) @@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str, raise ValueError(f'Quantization: {quantization} is not recognized') model.export_model(quantization_config=quantization_config) - model.export_labels(export_dir) def main(_) -> None: diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 252659edc..afda8643b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -24,6 +24,8 @@ import numpy as np import tensorflow as tf from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import model_options from mediapipe.tasks.python.test import test_utils @@ -133,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - 
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() @@ -159,15 +164,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( - image_classifier.hyperparameters, + hyperparameters, 'HParams', autospec=True, - return_value=image_classifier.HParams(epochs=1)) + return_value=hyperparameters.HParams(epochs=1)) @unittest_mock.patch.object( - image_classifier.model_options, + model_options, 'ImageClassifierModelOptions', autospec=True, - return_value=image_classifier.ModelOptions()) + return_value=model_options.ImageClassifierModelOptions()) def test_create_hparams_and_model_options_if_none_in_image_classifier_options( self, mock_hparams, mock_model_options): options = image_classifier.ImageClassifierOptions( diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py deleted file mode 100644 index c5b28cff5..000000000 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Library to train model.""" - -import os -import tensorflow as tf - -from mediapipe.model_maker.python.core.utils import model_util -from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp - - -def _create_optimizer(init_lr: float, decay_steps: int, - warmup_steps: int) -> tf.keras.optimizers.Optimizer: - """Creates an optimizer with learning rate schedule. - - Uses Keras CosineDecay schedule for the learning rate by default. - - Args: - init_lr: Initial learning rate. - decay_steps: Number of steps to decay over. - warmup_steps: Number of steps to do warmup for. - - Returns: - A tf.keras.optimizers.Optimizer for model training. - """ - learning_rate_fn = tf.keras.experimental.CosineDecay( - initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) - if warmup_steps: - learning_rate_fn = model_util.WarmUp( - initial_learning_rate=init_lr, - decay_schedule_fn=learning_rate_fn, - warmup_steps=warmup_steps) - optimizer = tf.keras.optimizers.RMSprop( - learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) - - return optimizer - - -def train_model(model: tf.keras.Model, hparams: hp.HParams, - train_ds: tf.data.Dataset, - validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: - """Trains model with the input data and hyperparameters. - - Args: - model: Input tf.keras.Model. - hparams: Hyperparameters for training image classifier. - train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit(). 
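Both the deleted `train_image_classifier_lib` being removed here and the new `ImageClassifier._create_optimizer` build the same schedule: the initial rate is scaled linearly with batch size, ramped up over the warmup epochs, then cosine-decayed over `max(total_training_steps, decay_samples // batch_size)` steps. The plain-Python sketch below traces that curve; the hyperparameter values are placeholders, and the linear-ramp behavior of `model_util.WarmUp` is an assumption about its implementation.

```python
import math

def lr_at_step(step, *, base_lr=0.001, batch_size=32, steps_per_epoch=100,
               epochs=10, warmup_epochs=2, decay_samples=64_000):
    """Approximate learning rate produced by the schedule described above."""
    # Learning rate is linear in batch size (reference batch size of 256).
    init_lr = base_lr * batch_size / 256
    total_training_steps = steps_per_epoch * epochs
    default_decay_steps = decay_samples // batch_size
    decay_steps = max(total_training_steps, default_decay_steps)
    warmup_steps = warmup_epochs * steps_per_epoch

    if warmup_steps and step < warmup_steps:
        # Linear warmup from 0 to init_lr (assumed WarmUp behavior).
        return init_lr * step / warmup_steps
    # Cosine decay from init_lr down to 0 (alpha=0.0 in the calling code).
    progress = min(step, decay_steps) / decay_steps
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

for s in (0, 50, 200, 1000, 2000):
    print(s, lr_at_step(s))
```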
- validation_ds: tf.data.Dataset, validation data to be fed in - tf.keras.Model.fit(). - - Returns: - The tf.keras.callbacks.History object returned by tf.keras.Model.fit(). - """ - - # Learning rate is linear to batch size. - learning_rate = hparams.learning_rate * hparams.batch_size / 256 - - # Get decay steps. - # NOMUTANTS--(b/256493858):Plan to test it in the unified training library. - total_training_steps = hparams.steps_per_epoch * hparams.epochs - default_decay_steps = hparams.decay_samples // hparams.batch_size - decay_steps = max(total_training_steps, default_decay_steps) - - warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch - optimizer = _create_optimizer( - init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps) - - loss = tf.keras.losses.CategoricalCrossentropy( - label_smoothing=hparams.label_smoothing) - model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) - - summary_dir = os.path.join(hparams.export_dir, 'summaries') - summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) - # Save checkpoint every 5 epochs. - checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint') - checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(checkpoint_path, 'model-{epoch:04d}'), - save_weights_only=True, - period=5) - - latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - if latest_checkpoint: - print(f'Resuming from {latest_checkpoint}') - model.load_weights(latest_checkpoint) - - # Train the model. - return model.fit( - x=train_ds, - epochs=hparams.epochs, - validation_data=validation_ds, - callbacks=[summary_callback, checkpoint_callback]) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 389ee484a..d7e4a950f 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,6 +1,8 @@ absl-py +mediapipe==0.9.0.1 numpy -opencv-contrib-python -tensorflow +opencv-python +tensorflow>=2.10 tensorflow-datasets tensorflow-hub +tf-models-official>=2.10.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py new file mode 100644 index 000000000..1dac6301a --- /dev/null +++ b/mediapipe/model_maker/setup.py @@ -0,0 +1,155 @@ +"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Setup for Mediapipe-Model-Maker package with setuptools. 
+""" + +import glob +import os +import shutil +import subprocess +import sys +import setuptools + + +__version__ = 'dev' +MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +# Build dir to copy all necessary files and build package +SRC_NAME = 'pip_src' +BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME) +BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker') + + +def _parse_requirements(path): + with open(os.path.join(MM_ROOT_PATH, path)) as f: + return [ + line.rstrip() + for line in f + if not (line.isspace() or line.startswith('#')) + ] + + +def _copy_to_pip_src_dir(file): + """Copy a file from bazel-bin to the pip_src dir.""" + dst = file + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) + shutil.copyfile(src_file, file) + + +def _setup_build_dir(): + """Setup the BUILD_DIR directory to build the mediapipe_model_maker package. + + We need to create a new BUILD_DIR directory because any references to the path + `mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to + avoid conflicting with the mediapipe package name. + This setup function performs the following actions: + 1. Copy python source code into BUILD_DIR and rename imports to + mediapipe_model_maker + 2. Download models from GCS into BUILD_DIR + """ + # Copy python source code into BUILD_DIR + if os.path.exists(BUILD_DIR): + shutil.rmtree(BUILD_DIR) + python_files = glob.glob('python/**/*.py', recursive=True) + python_files.append('__init__.py') + for python_file in python_files: + # Exclude test files from pip package + if '_test.py' in python_file: + continue + build_target_file = os.path.join(BUILD_MM_DIR, python_file) + with open(python_file, 'r') as file: + filedata = file.read() + # Rename all mediapipe.model_maker imports to mediapipe_model_maker + filedata = filedata.replace('from mediapipe.model_maker', + 'from mediapipe_model_maker') + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + with open(build_target_file, 'w') as file: + file.write(filedata) + + # Use bazel to download GCS model files + model_build_files = [ + 'models/gesture_recognizer/BUILD', + 'models/text_classifier/BUILD', + ] + for model_build_file in model_build_files: + build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + shutil.copy(model_build_file, build_target_file) + external_files = [ + 'models/gesture_recognizer/canned_gesture_classifier.tflite', + 'models/gesture_recognizer/gesture_embedder.tflite', + 'models/gesture_recognizer/hand_landmark_full.tflite', + 'models/gesture_recognizer/palm_detection_full.tflite', + 'models/gesture_recognizer/gesture_embedder/keras_metadata.pb', + 'models/gesture_recognizer/gesture_embedder/saved_model.pb', + 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', + 'models/gesture_recognizer/gesture_embedder/variables/variables.index', + 'models/text_classifier/mobilebert_tiny/keras_metadata.pb', + 'models/text_classifier/mobilebert_tiny/saved_model.pb', + 'models/text_classifier/mobilebert_tiny/assets/vocab.txt', + 'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001', + 'models/text_classifier/mobilebert_tiny/variables/variables.index', + ] + for elem in external_files: + external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) + sys.stderr.write('downloading file: %s\n' % external_file) + 
fetch_model_command = [ + 'bazel', + 'build', + external_file, + ] + if subprocess.call(fetch_model_command) != 0: + sys.exit(-1) + _copy_to_pip_src_dir(external_file) + +_setup_build_dir() + +setuptools.setup( + name='mediapipe-model-maker', + version=__version__, + url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + description='MediaPipe Model Maker is a simple, low-code solution for customizing on-device ML models', + author='The MediaPipe Authors', + author_email='mediapipe@google.com', + long_description='', + long_description_content_type='text/markdown', + packages=setuptools.find_packages(where=SRC_NAME), + package_dir={'': SRC_NAME}, + install_requires=_parse_requirements('requirements.txt'), + include_package_data=True, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3 :: Only', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords=['mediapipe', 'model', 'maker'], +) diff --git a/mediapipe/modules/hand_landmark/calculators/BUILD b/mediapipe/modules/hand_landmark/calculators/BUILD index b2a8efe37..b42ec94de 100644 --- a/mediapipe/modules/hand_landmark/calculators/BUILD +++ b/mediapipe/modules/hand_landmark/calculators/BUILD @@ -24,7 +24,6 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 6f2c49d64..638678ff5 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { // NORM_LANDMARKS is either the full set of landmarks for the hand, or diff --git a/mediapipe/modules/holistic_landmark/calculators/BUILD b/mediapipe/modules/holistic_landmark/calculators/BUILD index c3c091924..bc00b697c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/BUILD +++ b/mediapipe/modules/holistic_landmark/calculators/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "hand_detections_from_pose_to_rects_calculator", srcs = ["hand_detections_from_pose_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "roi_tracking_calculator_proto", 
srcs = ["roi_tracking_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -49,7 +47,6 @@ mediapipe_proto_library( cc_library( name = "roi_tracking_calculator", srcs = ["roi_tracking_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":roi_tracking_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 0da6cd7f7..49c7b93fb 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kTrackingRectTag[] = "TRACKING_RECT"; +using ::mediapipe::NormalizedRect; + // TODO: Use rect rotation. // Verifies that Intersection over Union of previous frame rect and current // frame re-crop rect is less than threshold. diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index eeeaee5f4..14cea526f 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -275,7 +275,6 @@ cc_library( ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -299,7 +298,6 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -316,13 +314,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_cc_proto", - ":belief_decoder_config_cc_proto", ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 476f8cb54..1fe919c54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -34,6 +34,8 @@ namespace { constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION"; constexpr char kOutputNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator that converts FrameAnnotation proto to NormalizedRect. diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 48c9b181a..c71c02b6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -68,7 +68,6 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = ["Accelerate"], # This build rule is public to allow external customers to build their own iOS apps. 
visibility = ["//visibility:public"], deps = [ @@ -83,13 +82,14 @@ objc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/port:threadpool", - "//mediapipe/gpu:MPPGraphGPUData", "//mediapipe/gpu:gl_base", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", + "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -120,13 +120,13 @@ objc_library( ], "//conditions:default": [], }), - sdk_frameworks = [ - "AVFoundation", - "CoreVideo", - "Foundation", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], + deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", + ], ) objc_library( @@ -140,16 +140,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -164,16 +162,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -188,13 +184,11 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Foundation", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", "@com_google_absl//absl/strings", ], ) @@ -211,23 +205,21 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "OpenGLES", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", ":Weakify", ":mediapipe_framework_ios", "//mediapipe/framework:calculator_framework", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:OpenGLES", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) @@ -245,16 +237,6 @@ objc_library( data = [ "testdata/googlelogo_color_272x92dp.png", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ 
":CGImageRefUtils", @@ -263,6 +245,14 @@ objc_library( ":mediapipe_framework_ios", ":mediapipe_input_sources_ios", "//mediapipe/calculators/core:pass_through_calculator", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 080cca20f..3123eb863 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -24,7 +24,6 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/objc/util.h" @@ -231,16 +230,17 @@ if ([wrapper.delegate } - (absl::Status)performStart { - absl::Status status = _graph->Initialize(_config); - if (!status.ok()) { - return status; - } + absl::Status status; for (const auto& service_packet : _servicePackets) { status = _graph->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { return status; } } + status = _graph->Initialize(_config); + if (!status.ok()) { + return status; + } status = _graph->StartRun(_inputSidePackets, _streamHeaders); if (!status.ok()) { return status; diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m index 7c3027fb6..edd2216ee 100644 --- a/mediapipe/objc/MPPLayerRenderer.m +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -54,10 +54,11 @@ glGenRenderbuffers(1, &renderbuffer_); glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); - BOOL success = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + BOOL success __unused = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER + fromDrawable:_layer]; NSAssert(success, @"could not create renderbuffer storage for layer with bounds %@", NSStringFromCGRect(_layer.bounds)); - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + GLenum status __unused = glCheckFramebufferStatus(GL_FRAMEBUFFER); NSAssert(status == GL_FRAMEBUFFER_COMPLETE, @"failed to make complete framebuffer object %x", status); } diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 21067e828..6e36ca1c2 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -94,10 +94,11 @@ cc_library( "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", ] + select({ # TODO: Build text_classifier_graph and text_embedder_graph on Windows. - # TODO: Build audio_classifier_graph on Windows. + # TODO: Build audio_classifier_graph and audio_embedder_graph on Windows. 
"//mediapipe:windows": [], "//conditions:default": [ "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", + "//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", ], diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index 117d20974..cd9124948 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -28,6 +28,8 @@ import PIL.Image from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' + Image = image.Image ImageFormat = image_frame.ImageFormat @@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) + def test_image_create_from_cvmat(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + mat = cv2.imread(image_path).astype(np.uint8) + mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB) + rgb_image = Image(image_format=ImageFormat.SRGB, data=mat) + self.assertEqual(rgb_image.width, 720) + self.assertEqual(rgb_image.height, 382) + self.assertEqual(rgb_image.channels, 3) + self.assertEqual(rgb_image.image_format, ImageFormat.SRGB) + self.assertTrue(np.array_equal(mat, rgb_image.numpy_view())) + + def test_image_create_from_file(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + loaded_image = Image.create_from_file(image_path) + self.assertEqual(loaded_image.width, 720) + self.assertEqual(loaded_image.height, 382) + self.assertEqual(loaded_image.channels, 3) + self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index e1a4c12af..16fc37c87 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase): p.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p), 0.42) self.assertEqual(p.timestamp, 0) - p2 = packet_creator.create_float(np.float(0.42)) + p2 = packet_creator.create_float(float(0.42)) p2.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p2), 0.42) self.assertEqual(p2.timestamp, 0) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 5d8663143..1bcca12ff 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -48,16 +48,20 @@ void ImageSubmodule(pybind11::module* module) { become immutable after creation. 
Creation examples: - import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) - gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) - from PIL import Image - pil_img = Image.new('RGB', (60, 30), color = 'red') - image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ```python + import cv2 + cv_mat = cv2.imread(input_file)[:, :, ::-1] + rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.Image( + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image = mp.Image( + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling `Image.numpy_view()`. The returned numpy ndarray is a reference to the @@ -65,15 +69,18 @@ void ImageSubmodule(pybind11::module* module) { numpy ndarray, it's required to obtain a copy of it. Pixel data retrieval examples: - for channel in range(num_channel): - for col in range(width): - for row in range(height): - print(image[row, col, channel]) - output_ndarray = image.numpy_view() - print(output_ndarray[0, 0, 0]) - copied_ndarray = np.copy(output_ndarray) - copied_ndarray[0,0,0] = 0 + ```python + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) + + output_ndarray = image.numpy_view() + print(output_ndarray[0, 0, 0]) + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 + ``` )doc", py::dynamic_attr()); @@ -156,9 +163,11 @@ void ImageSubmodule(pybind11::module* module) { An unwritable numpy ndarray. Examples: + ``` output_ndarray = image.numpy_view() copied_ndarray = np.copy(output_ndarray) copied_ndarray[0,0,0] = 0 + ``` )doc"); image.def( @@ -191,10 +200,12 @@ void ImageSubmodule(pybind11::module* module) { IndexError: If the index is invalid or out of bounds. Examples: + ``` for channel in range(num_channel): for col in range(width): for row in range(height): print(image[row, col, channel]) + ``` )doc"); image @@ -224,7 +235,9 @@ void ImageSubmodule(pybind11::module* module) { A boolean. Examples: + ``` image.is_aligned(16) + ``` )doc"); image.def_static( diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index a7fc6bfe4..bc7a9753d 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) { Creation examples: import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.ImageFrame( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image_frame = mp.ImageFrame( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling `ImageFrame.numpy_view()`. 
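Two related Python-side changes meet here: the constructor keyword in these docstrings becomes `image_format=` (matching the actual signature), and the new `image_test.py` cases above exercise `Image.create_from_file`. A compact sketch combining both follows; the `hands.jpg` path and its expected properties simply refer to the test asset used above.

```python
import cv2
import numpy as np
import mediapipe as mp

# Build an Image from an OpenCV matrix; the keyword is image_format=, not format=.
mat = cv2.cvtColor(cv2.imread('hands.jpg'), cv2.COLOR_BGR2RGB)
rgb_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat)

# Or decode straight from disk.
loaded = mp.Image.create_from_file('hands.jpg')
assert loaded.image_format == mp.ImageFormat.SRGB

# numpy_view() is a read-only reference; copy before mutating.
pixels = np.copy(rgb_image.numpy_view())
pixels[0, 0, 0] = 0
```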
The returned numpy ndarray is a reference to the diff --git a/mediapipe/python/solutions/drawing_styles.py b/mediapipe/python/solutions/drawing_styles.py index b43bca8d3..5d75d5b30 100644 --- a/mediapipe/python/solutions/drawing_styles.py +++ b/mediapipe/python/solutions/drawing_styles.py @@ -37,9 +37,10 @@ _THICKNESS_FINGER = 2 _THICKNESS_DOT = -1 # Hand landmarks -_PALM_LANMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC, - HandLandmark.INDEX_FINGER_MCP, HandLandmark.MIDDLE_FINGER_MCP, - HandLandmark.RING_FINGER_MCP, HandLandmark.PINKY_MCP) +_PALM_LANDMARKS = (HandLandmark.WRIST, HandLandmark.THUMB_CMC, + HandLandmark.INDEX_FINGER_MCP, + HandLandmark.MIDDLE_FINGER_MCP, HandLandmark.RING_FINGER_MCP, + HandLandmark.PINKY_MCP) _THUMP_LANDMARKS = (HandLandmark.THUMB_MCP, HandLandmark.THUMB_IP, HandLandmark.THUMB_TIP) _INDEX_FINGER_LANDMARKS = (HandLandmark.INDEX_FINGER_PIP, @@ -54,7 +55,7 @@ _RING_FINGER_LANDMARKS = (HandLandmark.RING_FINGER_PIP, _PINKY_FINGER_LANDMARKS = (HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP) _HAND_LANDMARK_STYLE = { - _PALM_LANMARKS: + _PALM_LANDMARKS: DrawingSpec( color=_RED, thickness=_THICKNESS_DOT, circle_radius=_RADIUS), _THUMP_LANDMARKS: diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index bebcbe97c..1b8b173f7 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MediaPipe solution drawing utils.""" import math @@ -135,15 +134,14 @@ def draw_landmarks( the image. connections: A list of landmark index tuples that specifies how landmarks to be connected in the drawing. - landmark_drawing_spec: Either a DrawingSpec object or a mapping from - hand landmarks to the DrawingSpecs that specifies the landmarks' drawing - settings such as color, line thickness, and circle radius. - If this argument is explicitly set to None, no landmarks will be drawn. - connection_drawing_spec: Either a DrawingSpec object or a mapping from - hand connections to the DrawingSpecs that specifies the - connections' drawing settings such as color and line thickness. - If this argument is explicitly set to None, no landmark connections will - be drawn. + landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand + landmarks to the DrawingSpecs that specifies the landmarks' drawing + settings such as color, line thickness, and circle radius. If this + argument is explicitly set to None, no landmarks will be drawn. + connection_drawing_spec: Either a DrawingSpec object or a mapping from hand + connections to the DrawingSpecs that specifies the connections' drawing + settings such as color and line thickness. If this argument is explicitly + set to None, no landmark connections will be drawn. 
Raises: ValueError: If one of the followings: @@ -197,14 +195,13 @@ def draw_landmarks( drawing_spec.color, drawing_spec.thickness) -def draw_axis( - image: np.ndarray, - rotation: np.ndarray, - translation: np.ndarray, - focal_length: Tuple[float, float] = (1.0, 1.0), - principal_point: Tuple[float, float] = (0.0, 0.0), - axis_length: float = 0.1, - axis_drawing_spec: DrawingSpec = DrawingSpec()): +def draw_axis(image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -214,8 +211,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - axis_drawing_spec: A DrawingSpec object that specifies the xyz axis - drawing settings such as line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis drawing + settings such as line thickness. Raises: ValueError: If one of the followings: @@ -226,7 +223,7 @@ def draw_axis( image_rows, image_cols, _ = image.shape # Create axis points in camera coordinate frame. axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation x = axis_cam[..., 0] y = axis_cam[..., 1] z = axis_cam[..., 2] @@ -274,8 +271,9 @@ def plot_landmarks(landmark_list: landmark_pb2.NormalizedLandmarkList, connections' drawing settings such as color and line thickness. elevation: The elevation from which to view the plot. azimuth: the azimuth angle to rotate the plot. + Raises: - ValueError: If any connetions contain invalid landmark index. + ValueError: If any connection contains an invalid landmark index. 
""" if not landmark_list: return diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 242a88cfc..98ddd5777 100644 --- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -21,3 +21,10 @@ package_group( "//mediapipe/tasks/...", ], ) + +package_group( + name = "users", + includes = [ + ":internal", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 1955adfe7..c575caabe 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -16,6 +16,33 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Classifier +# https://developers.google.com/mediapipe/solutions/audio/audio_classifier +cc_library( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + ":audio_classifier_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_classifier_graph", srcs = ["audio_classifier_graph.cc"], @@ -26,7 +53,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", @@ -52,28 +79,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_classifier", - srcs = ["audio_classifier.cc"], - hdrs = ["audio_classifier.h"], - deps = [ - ":audio_classifier_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # 
TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 596b910f8..2d5b221a9 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -143,8 +143,9 @@ void CheckStreamingModeResults(std::vector outputs) { EXPECT_EQ(outputs.size(), 5); // Ignore last result, which operates on a too small chunk to return relevant // results. + std::vector timestamps_ms = {0, 975, 1950, 2925}; for (int i = 0; i < outputs.size() - 1; i++) { - EXPECT_FALSE(outputs[i].timestamp_ms.has_value()); + EXPECT_EQ(outputs[i].timestamp_ms.value(), timestamps_ms[i]); EXPECT_EQ(outputs[i].classifications.size(), 1); EXPECT_EQ(outputs[i].classifications[0].head_index, 0); EXPECT_EQ(outputs[i].classifications[0].head_name, "scores"); diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 5d4ba3296..cc26b3070 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index b982ef39a..1dfdd6f1b 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -16,6 +16,34 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Embedder +# https://developers.google.com/mediapipe/solutions/audio/audio_embedder +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + visibility = ["//visibility:public"], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_embedder_graph", srcs = ["audio_embedder_graph.cc"], @@ -26,7 +54,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", 
"//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", @@ -51,29 +79,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_embedder", - srcs = ["audio_embedder.cc"], - hdrs = ["audio_embedder.h"], - deps = [ - ":audio_embedder_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:embedding_result", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedder_options", - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:cosine_similarity", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h index 4e7e20530..31cb61422 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h @@ -58,9 +58,12 @@ struct AudioEmbedderOptions { nullptr; }; -// Performs embedding extraction on audio clips or audio stream. +// Performs audio embedding extraction on audio clips or audio stream. // -// The API expects a TFLite model with TFLite Model Metadata. +// This API expects a TFLite model with mandatory TFLite Model Metadata that +// contains the mandatory AudioProperties of the solo input audio tensor and the +// optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS per output embedding tensor. // // Input tensor: // (kTfLiteFloat32) diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc index 7667feaa3..187f11f7f 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -100,6 +100,46 @@ void ConfigureAudioToTensorCalculator( } } // namespace +// An "AudioEmebdderGraph" performs embedding extractions. +// - Accepts CPU audio buffer and outputs embedding results on CPU. +// +// Inputs: +// AUDIO - Matrix +// Audio buffer to perform classification on. +// SAMPLE_RATE - double @Optional +// The sample rate of the corresponding audio data in the "AUDIO" stream. +// If sample rate is not provided, the "AUDIO" stream must carry a time +// series stream header with sample rate info. +// +// Outputs: +// EMBEDDINGS - EmbeddingResult @Optional +// The embedding results aggregated by head. Only produces results if +// the graph if the 'use_stream_mode' option is true. 
+// TIMESTAMPED_EMBEDDINGS - std::vector @Optional +// The embedding result aggregated by timestamp, then by head. Only +// produces results if the graph if the 'use_stream_mode' option is false. +// +// Example: +// node { +// calculator: "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph" +// input_stream: "AUDIO:audio_in" +// input_stream: "SAMPLE_RATE:sample_rate_in" +// output_stream: "EMBEDDINGS:embeddings_out" +// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out" +// options { +// [mediapipe.tasks.audio.audio_embedder.proto.AudioEmbedderGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// embedder_options { +// l2_normalize: true +// } +// } +// } +// } class AudioEmbedderGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( @@ -158,10 +198,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio embedding on // audio files. Disables timestamp aggregation by not connecting the diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto index 25c5d5474..367a1bf26 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD index 93362fd3d..016faa10f 100644 --- a/mediapipe/tasks/cc/audio/core/BUILD +++ b/mediapipe/tasks/cc/audio/core/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD deleted file mode 100644 index c90349ab2..000000000 --- a/mediapipe/tasks/cc/components/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
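The audio embedder targets above add a dependency on `components/utils:cosine_similarity`, since embeddings from two clips are usually compared with that metric. As a reminder of the math only, here is a generic NumPy version; it is not the MediaPipe utility itself.

```python
import numpy as np

def cosine_similarity(u, v) -> float:
    """Cosine similarity between two embedding vectors, in [-1, 1]."""
    u = np.asarray(u, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    denom = float(np.linalg.norm(u) * np.linalg.norm(v))
    if denom == 0.0:
        raise ValueError('Cosine similarity is undefined for zero vectors.')
    return float(np.dot(u, v)) / denom

print(cosine_similarity([1.0, 0.0, 1.0], [0.5, 0.5, 0.5]))  # ~0.816
```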
- -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "image_preprocessing_options_proto", - srcs = ["image_preprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -cc_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.cc"], - hdrs = ["image_preprocessing.h"], - deps = [ - ":image_preprocessing_options_cc_proto", - "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/calculators/image:image_clone_calculator", - "//mediapipe/calculators/image:image_clone_calculator_cc_proto", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/gpu:gpu_origin_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO: Enable this test - -# TODO: Investigate rewriting the build rule to only link -# the Bert Preprocessor if it's needed. 
-cc_library( - name = "text_preprocessing_graph", - srcs = ["text_preprocessing_graph.cc"], - hdrs = ["text_preprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:bert_preprocessor_calculator", - "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:text_to_tensor_calculator", - "//mediapipe/framework:subgraph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..145076cd3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. 
- ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -197,6 +188,7 @@ ClassificationAggregationCalculator::ConvertToClassificationResult( *classifications->mutable_classification_list() = std::move(classification_lists[i]); } + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); cached_classifications_.erase(cc->InputTimestamp().Value()); return result; } @@ -226,55 +218,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc index 1bc8cafd6..811d70544 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc +++ 
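Editorial note (not part of the patch): the aggregation calculators above now stamp the non-aggregated result with the packet time converted from microseconds to milliseconds. A minimal sketch of that conversion, assuming only the MediaPipe Timestamp API:

// Sketch: MediaPipe packet timestamps are expressed in microseconds, while
// ClassificationResult.timestamp_ms (and EmbeddingResult.timestamp_ms) are in
// milliseconds, hence the division by 1000 in the calculators above.
#include <cstdint>
#include "mediapipe/framework/timestamp.h"

inline int64_t ToResultTimestampMs(const mediapipe::Timestamp& packet_time) {
  return packet_time.Value() / 1000;  // e.g. Timestamp(2000) us -> 2 ms
}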
b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc @@ -150,14 +150,15 @@ class ClassificationAggregationCalculatorTest CalculatorGraph calculator_graph_; }; -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph()); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( - R"pb(classifications { + R"pb(timestamp_ms: 0, + classifications { head_index: 0 head_name: "foo" classification_list { classification { index: 0 } } @@ -169,7 +170,7 @@ TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { })pb"))); } -TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); MP_ASSERT_OK(Send( diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc index bae926b76..6e06c4e32 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc @@ -120,7 +120,9 @@ absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) { } kTimestampedEmbeddingsOut(cc).Send(std::move(results)); } else { - kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc)); + auto result = kEmbeddingsIn(cc).Get(); + result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000); + kEmbeddingsOut(cc).Send(result); } RET_CHECK(cached_embeddings_.empty()); return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc index ebb4d8880..f2b2fa1d5 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc @@ -120,7 +120,7 @@ class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { CalculatorGraph calculator_graph_; }; -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutAggregation) { EmbeddingResult embedding = ParseTextProtoOrDie( R"pb(embeddings { head_index: 0 })pb"); @@ -129,10 +129,12 @@ TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { MP_ASSERT_OK(Send(embedding)); MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); - EXPECT_THAT(result, EqualsProto(embedding)); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb(timestamp_ms: 0 + embeddings { head_index: 0 })pb"))); } -TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) { +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithAggregation) { MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); MP_ASSERT_OK(Send(ParseTextProtoOrDie(R"pb(embeddings { head_index: 0 diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index bd66a0f28..a7307b2ce 100644 --- 
a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = "detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], @@ -49,3 +62,12 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "landmark", + srcs = ["landmark.cc"], + hdrs = ["landmark.h"], + deps = [ + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/containers/classification_result.cc b/mediapipe/tasks/cc/components/containers/classification_result.cc index 98583ff15..f2d88406d 100644 --- a/mediapipe/tasks/cc/components/containers/classification_result.cc +++ b/mediapipe/tasks/cc/components/containers/classification_result.cc @@ -40,6 +40,19 @@ Classifications ConvertToClassifications(const proto::Classifications& proto) { return classifications; } +Classifications ConvertToClassifications( + const mediapipe::ClassificationList& proto, int head_index, + std::optional head_name) { + Classifications classifications; + classifications.categories.reserve(proto.classification_size()); + for (const auto& classification : proto.classification()) { + classifications.categories.push_back(ConvertToCategory(classification)); + } + classifications.head_index = head_index; + classifications.head_name = head_name; + return classifications; +} + ClassificationResult ConvertToClassificationResult( const proto::ClassificationResult& proto) { ClassificationResult classification_result; diff --git a/mediapipe/tasks/cc/components/containers/classification_result.h b/mediapipe/tasks/cc/components/containers/classification_result.h index 88273fd00..e359fb33e 100644 --- a/mediapipe/tasks/cc/components/containers/classification_result.h +++ b/mediapipe/tasks/cc/components/containers/classification_result.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" @@ -58,6 +59,12 @@ struct ClassificationResult { // Classifications struct. Classifications ConvertToClassifications(const proto::Classifications& proto); +// Utility function to convert from ClassificationList proto to +// Classifications struct. +Classifications ConvertToClassifications( + const mediapipe::ClassificationList& proto, int head_index = 0, + std::optional head_name = std::nullopt); + // Utility function to convert from ClassificationResult proto to // ClassificationResult struct. 
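Editorial note (not part of the patch): a minimal usage sketch for the new ClassificationList-based ConvertToClassifications overload declared above; the values are made up for illustration.

// Sketch: converting a raw ClassificationList proto into the Classifications
// struct, supplying an explicit head name (head_index defaults to 0 and
// head_name to std::nullopt when omitted).
#include <optional>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"

int main() {
  mediapipe::ClassificationList list;
  auto* classification = list.add_classification();
  classification->set_index(2);
  classification->set_score(0.9f);
  classification->set_label("goldfish");

  auto classifications =
      mediapipe::tasks::components::containers::ConvertToClassifications(
          list, /*head_index=*/0, /*head_name=*/"probability");
  return classifications.categories.size() == 1 ? 0 : 1;
}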
ClassificationResult ConvertToClassificationResult( diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..38126f917 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,71 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +constexpr int kDefaultCategoryIndex = -1; + +Detection ConvertToDetectionResult( + const mediapipe::Detection& detection_proto) { + Detection detection; + for (int idx = 0; idx < detection_proto.score_size(); ++idx) { + detection.categories.push_back( + {/* index= */ detection_proto.label_id_size() > idx + ? detection_proto.label_id(idx) + : kDefaultCategoryIndex, + /* score= */ detection_proto.score(idx), + /* category_name */ detection_proto.label_size() > idx + ? detection_proto.label(idx) + : "", + /* display_name */ detection_proto.display_name_size() > idx + ? detection_proto.display_name(idx) + : ""}); + } + Rect bounding_box; + if (detection_proto.location_data().has_bounding_box()) { + mediapipe::LocationData::BoundingBox bounding_box_proto = + detection_proto.location_data().bounding_box(); + bounding_box.left = bounding_box_proto.xmin(); + bounding_box.top = bounding_box_proto.ymin(); + bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width(); + bounding_box.bottom = + bounding_box_proto.ymin() + bounding_box_proto.height(); + } + detection.bounding_box = bounding_box; + return detection; +} + +DetectionResult ConvertToDetectionResult( + std::vector detections_proto) { + DetectionResult detection_result; + detection_result.detections.reserve(detections_proto.size()); + for (const auto& detection_proto : detections_proto) { + detection_result.detections.push_back( + ConvertToDetectionResult(detection_proto)); + } + return detection_result; +} +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h new file mode 100644 index 000000000..546f324d6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +// Detection for a single bounding box. +struct Detection { + // A vector of detected categories. + std::vector categories; + // The bounding box location. + Rect bounding_box; +}; + +// Detection results of a model. +struct DetectionResult { + // A vector of Detections. + std::vector detections; +}; + +// Utility function to convert from Detection proto to Detection struct. +Detection ConvertToDetection(const mediapipe::Detection& detection_proto); + +// Utility function to convert from list of Detection proto to DetectionResult +// struct. +DetectionResult ConvertToDetectionResult( + std::vector detections_proto); + +} // namespace mediapipe::tasks::components::containers +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/landmark.cc b/mediapipe/tasks/cc/components/containers/landmark.cc new file mode 100644 index 000000000..6d80cb835 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.cc @@ -0,0 +1,65 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +#include +#include + +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::tasks::components::containers { + +Landmark ConvertToLandmark(const mediapipe::Landmark& proto) { + return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(), + /*visibility=*/proto.has_visibility() + ? std::optional(proto.visibility()) + : std::nullopt, + /*presence=*/proto.has_presence() + ? std::optional(proto.presence()) + : std::nullopt}; +} + +NormalizedLandmark ConvertToNormalizedLandmark( + const mediapipe::NormalizedLandmark& proto) { + return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(), + /*visibility=*/proto.has_visibility() + ? std::optional(proto.visibility()) + : std::nullopt, + /*presence=*/proto.has_presence() + ? 
std::optional(proto.presence()) + : std::nullopt}; +} + +Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto) { + Landmarks landmarks; + landmarks.landmarks.reserve(proto.landmark_size()); + for (const auto& landmark : proto.landmark()) { + landmarks.landmarks.push_back(ConvertToLandmark(landmark)); + } + return landmarks; +} + +NormalizedLandmarks ConvertToNormalizedLandmarks( + const mediapipe::NormalizedLandmarkList& proto) { + NormalizedLandmarks landmarks; + landmarks.landmarks.reserve(proto.landmark_size()); + for (const auto& landmark : proto.landmark()) { + landmarks.landmarks.push_back(ConvertToNormalizedLandmark(landmark)); + } + return landmarks; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h new file mode 100644 index 000000000..15b730001 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.h @@ -0,0 +1,103 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ + +#include +#include +#include + +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::tasks::components::containers { +constexpr float kLandmarkTolerance = 1e-6; + +// Landmark represents a point in 3D space with x, y, z coordinates. The +// landmark coordinates are in meters. z represents the landmark depth, and the +// smaller the value the closer the world landmark is to the camera. +struct Landmark { + float x; + float y; + float z; + // Landmark visibility. Should stay unset if not supported. + // Float score of whether landmark is visible or occluded by other objects. + // Landmark considered as invisible also if it is not present on the screen + // (out of scene bounds). Depending on the model, visibility value is either a + // sigmoid or an argument of sigmoid. + std::optional visibility = std::nullopt; + // Landmark presence. Should stay unset if not supported. + // Float score of whether landmark is present on the scene (located within + // scene bounds). Depending on the model, presence value is either a result of + // sigmoid or an argument of sigmoid function to get landmark presence + // probability. + std::optional presence = std::nullopt; + // Landmark name. Should stay unset if not supported. + std::optional name = std::nullopt; +}; + +inline bool operator==(const Landmark& lhs, const Landmark& rhs) { + return abs(lhs.x - rhs.x) < kLandmarkTolerance && + abs(lhs.y - rhs.y) < kLandmarkTolerance && + abs(lhs.z - rhs.z) < kLandmarkTolerance; +} + +// A normalized version of above Landmark struct. All coordinates should be +// within [0, 1]. 
+struct NormalizedLandmark { + float x; + float y; + float z; + std::optional visibility = std::nullopt; + std::optional presence = std::nullopt; + std::optional name = std::nullopt; +}; + +inline bool operator==(const NormalizedLandmark& lhs, + const NormalizedLandmark& rhs) { + return abs(lhs.x - rhs.x) < kLandmarkTolerance && + abs(lhs.y - rhs.y) < kLandmarkTolerance && + abs(lhs.z - rhs.z) < kLandmarkTolerance; +} + +// A list of Landmarks. +struct Landmarks { + std::vector landmarks; +}; + +// A list of NormalizedLandmarks. +struct NormalizedLandmarks { + std::vector landmarks; +}; + +// Utility function to convert from Landmark proto to Landmark struct. +Landmark ConvertToLandmark(const mediapipe::Landmark& proto); + +// Utility function to convert from NormalizedLandmark proto to +// NormalizedLandmark struct. +NormalizedLandmark ConvertToNormalizedLandmark( + const mediapipe::NormalizedLandmark& proto); + +// Utility function to convert from LandmarkList proto to Landmarks struct. +Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto); + +// Utility function to convert from NormalizedLandmarkList proto to +// NormalizedLandmarks struct. +NormalizedLandmarks ConvertToNormalizedLandmarks( + const mediapipe::NormalizedLandmarkList& proto); + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. 
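Editorial note (not part of the patch): an illustrative sketch of the landmark converters introduced above, using a hand-written NormalizedLandmarkList.

// Sketch: converting a NormalizedLandmarkList proto into the new
// NormalizedLandmarks container struct.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"

int main() {
  mediapipe::NormalizedLandmarkList list;
  auto* landmark = list.add_landmark();
  landmark->set_x(0.25f);
  landmark->set_y(0.5f);
  landmark->set_z(0.0f);
  landmark->set_visibility(0.9f);

  auto landmarks = mediapipe::tasks::components::containers::
      ConvertToNormalizedLandmarks(list);
  // visibility/presence remain std::nullopt when the proto fields are unset.
  return landmarks.landmarks[0].visibility.has_value() ? 0 : 1;
}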
- // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. -message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc new file mode 100644 index 000000000..4a94832a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/rect.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +RectF ToRectF(const Rect& rect, int image_height, int image_width) { + return RectF{static_cast(rect.left) / image_width, + static_cast(rect.top) / image_height, + static_cast(rect.right) / image_width, + static_cast(rect.bottom) / image_height}; +} + +Rect ToRect(const RectF& rect, int image_height, int image_width) { + return Rect{static_cast(rect.left * image_width), + static_cast(rect.top * image_height), + static_cast(rect.right * image_width), + static_cast(rect.bottom * image_height)}; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h index 3f5432cf2..72c7a8acb 100644 --- a/mediapipe/tasks/cc/components/containers/rect.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -16,20 +16,48 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#include +#include + namespace mediapipe::tasks::components::containers { +constexpr float kRectFTolerance = 1e-4; + // Defines a rectangle, used e.g. as part of detection results or as input // region-of-interest. // +struct Rect { + int left; + int top; + int right; + int bottom; +}; + +inline bool operator==(const Rect& lhs, const Rect& rhs) { + return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right && + lhs.bottom == rhs.bottom; +} + // The coordinates are normalized wrt the image dimensions, i.e. generally in // [0,1] but they may exceed these bounds if describing a region overlapping the // image. The origin is on the top-left corner of the image. 
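Editorial note (not part of the patch): a small round-trip sketch for the ToRectF/ToRect helpers defined in rect.cc above; the image size is arbitrary.

// Sketch: converting a pixel-space Rect to a normalized RectF and back.
#include "mediapipe/tasks/cc/components/containers/rect.h"

int main() {
  using mediapipe::tasks::components::containers::Rect;
  using mediapipe::tasks::components::containers::RectF;
  using mediapipe::tasks::components::containers::ToRect;
  using mediapipe::tasks::components::containers::ToRectF;

  constexpr int kImageWidth = 640;
  constexpr int kImageHeight = 480;
  Rect pixel_rect{/*left=*/64, /*top=*/48, /*right=*/320, /*bottom=*/240};

  // Normalized by the image dimensions: {0.1f, 0.1f, 0.5f, 0.5f}.
  RectF normalized = ToRectF(pixel_rect, kImageHeight, kImageWidth);
  // ...and back to pixel coordinates.
  Rect round_trip = ToRect(normalized, kImageHeight, kImageWidth);
  return round_trip == pixel_rect ? 0 : 1;
}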
-struct Rect { +struct RectF { float left; float top; float right; float bottom; }; +inline bool operator==(const RectF& lhs, const RectF& rhs) { + return std::fabs(lhs.left - rhs.left) < kRectFTolerance && + std::fabs(lhs.top - rhs.top) < kRectFTolerance && + std::fabs(lhs.right - rhs.right) < kRectFTolerance && + std::fabs(lhs.bottom - rhs.bottom) < kRectFTolerance; +} + +RectF ToRectF(const Rect& rect, int image_height, int image_width); + +Rect ToRect(const RectF& rect, int image_height, int image_width); + } // namespace mediapipe::tasks::components::containers #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 7845a3dae..10bc0726a 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -20,6 +20,7 @@ cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], ) @@ -47,7 +48,6 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", @@ -67,6 +67,7 @@ cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], hdrs = ["embedder_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], ) @@ -88,7 +89,6 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", "@com_google_absl//absl/status", @@ -98,3 +98,65 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "image_preprocessing_graph", + srcs = ["image_preprocessing_graph.cc"], + hdrs = ["image_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + 
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO: Enable this test + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. +cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..cfb3b02cf 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -40,7 +40,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -68,12 +67,11 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; -using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; +using TensorsSource = mediapipe::api2::builder::Source>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +80,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +397,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. 
Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +412,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -462,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } // If output tensors are quantized, they must be dequantized first. - TensorsSource dequantized_tensors(&tensors_in); + TensorsSource dequantized_tensors = tensors_in; if (options.has_quantized_outputs()) { GenericNode* tensors_dequantization_node = &graph.AddNode("TensorsDequantizationCalculator"); tensors_in >> tensors_dequantization_node->In(kTensorsTag); - dequantized_tensors = {tensors_dequantization_node, kTensorsTag}; + dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag) + .Cast>(); } // If there are multiple classification heads, the output tensors need to be @@ -484,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { auto* range = split_tensor_vector_options.add_ranges(); range->set_begin(i); range->set_end(i + 1); - split_tensors.emplace_back(split_tensor_vector_node, i); + split_tensors.push_back( + split_tensor_vector_node->Out(i).Cast>()); } dequantized_tensors >> split_tensor_vector_node->In(0); } else { @@ -501,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { score_calibration_node->GetOptions() .CopyFrom(options.score_calibration_options().at(i)); split_tensors[i] >> score_calibration_node->In(kScoresTag); - calibrated_tensors.emplace_back(score_calibration_node, - kCalibratedScoresTag); + calibrated_tensors.push_back( + score_calibration_node->Out(kCalibratedScoresTag) + .Cast>()); } else { calibrated_tensors.emplace_back(split_tensors[i]); } @@ -536,8 +531,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. 
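Editorial note (not part of the patch): with the deprecated CLASSIFICATION_RESULT stream removed, a task graph wires this subgraph roughly as sketched below. The surrounding graph, model_resources and classifier_options are assumed to exist, and the stream names are placeholders.

// Sketch: configuring and connecting ClassificationPostprocessingGraph via the
// api2 graph builder, exposing only the CLASSIFICATIONS output.
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

namespace {
using ::mediapipe::Tensor;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
namespace processors = ::mediapipe::tasks::components::processors;

absl::Status AddClassificationPostprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources,
    const processors::proto::ClassifierOptions& classifier_options,
    Graph& graph) {
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors."
      "ClassificationPostprocessingGraph");
  MP_RETURN_IF_ERROR(processors::ConfigureClassificationPostprocessingGraph(
      model_resources, classifier_options,
      &postprocessing.GetOptions<
          processors::proto::ClassificationPostprocessingGraphOptions>()));
  graph[Input<std::vector<Tensor>>("TENSORS")] >> postprocessing.In("TENSORS");
  postprocessing.Out("CLASSIFICATIONS") >>
      graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  return absl::OkStatus();
}
}  // namespace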
absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..a11bad71a 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -536,6 +534,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Validate results. EXPECT_THAT(results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 classification_list { @@ -569,6 +568,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -605,6 +605,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "probability" @@ -648,6 +649,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Validate results. EXPECT_THAT( results, EqualsProto(ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, classifications { head_index: 0 head_name: "yamnet_classification" @@ -728,326 +730,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." 
- "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. - tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. 
- std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. 
- std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 880aec5d7..7b023ba41 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::ModelResources; -using TensorsSource = - ::mediapipe::tasks::SourceOrNodeOutput>; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; @@ -150,7 +147,7 @@ absl::StatusOr> GetHeadNames( } // namespace -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { @@ -193,8 +190,8 @@ absl::Status ConfigureEmbeddingPostprocessing( // timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using -// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more -// details. +// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for +// more details. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( @@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { Source> tensors_in, Source> timestamps_in, Graph& graph) { // If output tensors are quantized, they must be dequantized first. 
- TensorsSource dequantized_tensors(&tensors_in); + Source> dequantized_tensors = tensors_in; if (options.has_quantized_outputs()) { GenericNode& tensors_dequantization_node = graph.AddNode("TensorsDequantizationCalculator"); tensors_in >> tensors_dequantization_node.In(kTensorsTag); - dequantized_tensors = {&tensors_dequantization_node, kTensorsTag}; + dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag) + .Cast>(); } // Adds TensorsToEmbeddingsCalculator. diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 58606ed80..889992463 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -58,7 +58,7 @@ namespace processors { // The embedding result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 84d84d648..809268a63 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { options_in.set_quantize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( *model_resources, options, &postprocessing .GetOptions())); @@ -246,7 +246,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { absl::make_unique>(); }; -TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { +TEST_F(PostprocessingTest, SucceedsWithoutAggregation) { // Build graph. 
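Editorial note (not part of the patch): the removal of SourceOrNodeOutput above boils down to expressing the optional dequantization step with plain builder Sources. A sketch of the pattern, with the template arguments written out since they are elided in the patch text:

// Sketch: pass tensors through unchanged, or route them through
// TensorsDequantizationCalculator when the model has quantized outputs.
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/tensor.h"

namespace {
using ::mediapipe::Tensor;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

Source<std::vector<Tensor>> MaybeDequantize(
    Source<std::vector<Tensor>> tensors_in, bool has_quantized_outputs,
    Graph& graph) {
  if (!has_quantized_outputs) return tensors_in;
  auto& dequantize = graph.AddNode("TensorsDequantizationCalculator");
  tensors_in >> dequantize.In("TENSORS");
  return dequantize.Out("TENSORS").Cast<std::vector<Tensor>>();
}
}  // namespace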
proto::EmbedderOptions options; MP_ASSERT_OK_AND_ASSIGN(auto poller, @@ -261,7 +261,8 @@ TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult(poller)); // Validate results. - EXPECT_FALSE(results.has_timestamp_ms()); + EXPECT_TRUE(results.has_timestamp_ms()); + EXPECT_EQ(results.timestamp_ms(), 0); EXPECT_EQ(results.embeddings_size(), 1); EXPECT_EQ(results.embeddings(0).head_index(), 0); EXPECT_EQ(results.embeddings(0).head_name(), "feature"); @@ -273,7 +274,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { } } -TEST_F(PostprocessingTest, SucceedsWithTimestamps) { +TEST_F(PostprocessingTest, SucceedsWithAggregation) { // Build graph. proto::EmbedderOptions options; MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options, diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/image_preprocessing.cc rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index ef447df97..fefc1ec52 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" @@ -42,8 +42,10 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::Tensor; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; @@ -144,9 +146,9 @@ bool DetermineImagePreprocessingGpuBackend( return acceleration.has_gpu(); } -absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, - bool use_gpu, - ImagePreprocessingOptions* options) { +absl::Status ConfigureImagePreprocessingGraph( + const ModelResources& model_resources, bool use_gpu, + proto::ImagePreprocessingGraphOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( @@ -154,9 +156,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. 
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { - options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND); } else { - options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND); } return absl::OkStatus(); } @@ -170,8 +172,7 @@ Source AddDataConverter(Source image_in, Graph& graph, return image_converter[Output("")]; } -// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image -// preprocessing. +// An ImagePreprocessingGraph performs image preprocessing. // - Accepts CPU input images and outputs CPU tensors. // // Inputs: @@ -192,7 +193,7 @@ Source AddDataConverter(Source image_in, Graph& graph, // An std::array representing the letterbox padding from the 4 // sides ([left, top, right, bottom]) of the output image, normalized to // [0.f, 1.f] by the output dimensions. The padding values are non-zero only -// when the "keep_aspect_ratio" is true in ImagePreprocessingOptions. +// when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions. // IMAGE_SIZE - std::pair @Optional // The size of the original input image as a pair. // IMAGE - Image @Optional @@ -200,15 +201,15 @@ Source AddDataConverter(Source image_in, Graph& graph, // GPU). // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureImagePreprocessing()' function. See header file for more -// details. -class ImagePreprocessingSubgraph : public Subgraph { +// using the 'ConfigureImagePreprocessingGraph()' function. See header file for +// more details. +class ImagePreprocessingGraph : public Subgraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; auto output_streams = BuildImagePreprocessing( - sc->Options(), + sc->Options(), graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph); output_streams.tensors >> graph[Output>(kTensorsTag)]; @@ -233,24 +234,25 @@ class ImagePreprocessingSubgraph : public Subgraph { // - the image that has pixel data stored on the target storage // (mediapipe::Image). // - // options: the mediapipe tasks ImagePreprocessingOptions. + // options: the mediapipe tasks ImagePreprocessingGraphOptions. // image_in: (mediapipe::Image) stream to preprocess. // graph: the mediapipe builder::Graph instance to be updated. ImagePreprocessingOutputStreams BuildImagePreprocessing( - const ImagePreprocessingOptions& options, Source image_in, - Source norm_rect_in, Graph& graph) { + const proto::ImagePreprocessingGraphOptions& options, + Source image_in, Source norm_rect_in, + Graph& graph) { // Convert image to tensor. 
auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); image_to_tensor.GetOptions() .CopyFrom(options.image_to_tensor_options()); switch (options.backend()) { - case ImagePreprocessingOptions::CPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: { auto cpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/false); cpu_image >> image_to_tensor.In(kImageTag); break; } - case ImagePreprocessingOptions::GPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: { auto gpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/true); gpu_image >> image_to_tensor.In(kImageTag); @@ -284,8 +286,9 @@ class ImagePreprocessingSubgraph : public Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ImagePreprocessingSubgraph); + ::mediapipe::tasks::components::processors::ImagePreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h similarity index 72% rename from mediapipe/tasks/cc/components/image_preprocessing.h rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h index 6963b6556..455a9b316 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h @@ -13,35 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures an ImagePreprocessing subgraph using the provided model resources +// Configures an ImagePreprocessingGraph using the provided model resources // When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph"); // core::proto::Acceleration acceleration; // acceleration.mutable_xnnpack(); // bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); -// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( +// MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph( // model_resources, // use_gpu, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ImagePreprocessing subgraph has the following I/O: +// The resulting ImagePreprocessingGraph has the following I/O: // Inputs: // IMAGE - Image // The image to preprocess. @@ -61,17 +62,18 @@ namespace components { // IMAGE - Image @Optional // The image that has the pixel data stored on the target storage (CPU vs // GPU). 
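// (Hedged usage sketch, not taken verbatim from the patch.) Backend selection
// in practice: DetermineImagePreprocessingGpuBackend() only reports whether
// the Acceleration proto requests GPU, and ConfigureImagePreprocessingGraph()
// then falls back to CPU_BACKEND whenever the model expects uint8 input
// tensors, because the GPU path cannot consume integer data:
//
//   core::proto::Acceleration acceleration;
//   acceleration.mutable_gpu();
//   bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);  // true
//   proto::ImagePreprocessingGraphOptions options;
//   MP_RETURN_IF_ERROR(
//       ConfigureImagePreprocessingGraph(model_resources, use_gpu, &options));
//   // For a quantized (uint8) model, options.backend() is still CPU_BACKEND.
//
// The LETTERBOX_PADDING output is normalized by the output size: fitting a
// 480x325 input into a square tensor with keep_aspect_ratio leaves
// (1 - 325/480) / 2 ~= 0.161458 of padding at the top and bottom, which is the
// value the parameterized test added later in this change expects.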
-absl::Status ConfigureImagePreprocessing( +absl::Status ConfigureImagePreprocessingGraph( const core::ModelResources& model_resources, bool use_gpu, - ImagePreprocessingOptions* options); + proto::ImagePreprocessingGraphOptions* options); -// Determine if the image preprocessing subgraph should use GPU as the backend +// Determine if the image preprocessing graph should use GPU as the backend // according to the given acceleration setting. bool DetermineImagePreprocessingGpuBackend( const core::proto::Acceleration& acceleration); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc new file mode 100644 index 000000000..6c094c6bc --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::testing::ContainerEq; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + 
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +constexpr char kMobileNetFloatWithoutMetadata[] = + "mobilenet_v1_0.25_224_1_default_1.tflite"; +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kMobileNetQuantizedWithoutMetadata[] = + "mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; + +constexpr char kTestImage[] = "burger.jpg"; +constexpr int kTestImageWidth = 480; +constexpr int kTestImageHeight = 325; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; +constexpr std::array kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 1}; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kMatrixName[] = "matrix_out"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTensorsName[] = "tensors_out"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kImageSizeName[] = "image_size_out"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLetterboxPaddingName[] = "letterbox_padding_out"; + +constexpr float kLetterboxMaxAbsError = 1e-5; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +// Helper function to create a TaskRunner from ModelResources. +absl::StatusOr> CreateTaskRunner( + const ModelResources& model_resources, bool keep_aspect_ratio) { + Graph graph; + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + auto& options = + preprocessing.GetOptions(); + options.mutable_image_to_tensor_options()->set_keep_aspect_ratio( + keep_aspect_ratio); + MP_RETURN_IF_ERROR( + ConfigureImagePreprocessingGraph(model_resources, false, &options)); + graph[Input(kImageTag)].SetName(kImageName) >> + preprocessing.In(kImageTag); + preprocessing.Out(kTensorsTag).SetName(kTensorsName) >> + graph[Output>(kTensorsTag)]; + preprocessing.Out(kMatrixTag).SetName(kMatrixName) >> + graph[Output>(kMatrixTag)]; + preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >> + graph[Output>(kImageSizeTag)]; + preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >> + graph[Output>(kLetterboxPaddingTag)]; + + return TaskRunner::Create(graph.GetConfig()); +} + +class ConfigureTest : public tflite_shims::testing::Test {}; + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + 
ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 192 + output_tensor_height: 192 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: GPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + auto status = + ConfigureImagePreprocessingGraph(*model_resources, false, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_THAT(status.message(), + HasSubstr("requires specifying NormalizationOptions metadata")); +} + +// Struct holding the parameters for parameterized PreprocessingTest class. +struct PreprocessingParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // If true, keep test image aspect ratio. + bool keep_aspect_ratio; + // The expected output tensor type. + Tensor::ElementType expected_type; + // The expected outoput tensor shape. 
+ std::vector expected_shape; + // The expected output letterbox padding; + std::array expected_letterbox_padding; +}; + +class PreprocessingTest : public testing::TestWithParam {}; + +TEST_P(PreprocessingTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio)); + + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + MP_ASSERT_OK(output_packets); + + const std::vector& tensors = + (*output_packets)[kTensorsName].Get>(); + EXPECT_EQ(tensors.size(), 1); + EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type); + EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape)); + auto& matrix = (*output_packets)[kMatrixName].Get>(); + if (!GetParam().keep_aspect_ratio) { + for (int i = 0; i < matrix.size(); ++i) { + EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]); + } + } + auto& image_size = + (*output_packets)[kImageSizeName].Get>(); + EXPECT_EQ(image_size.first, kTestImageWidth); + EXPECT_EQ(image_size.second, kTestImageHeight); + std::array letterbox_padding = + (*output_packets)[kLetterboxPaddingName].Get>(); + for (int i = 0; i < letterbox_padding.size(); ++i) { + EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i], + kLetterboxMaxAbsError); + } +} + +INSTANTIATE_TEST_SUITE_P( + PreprocessingTest, PreprocessingTest, + Values( + PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata", + .input_model_name = kMobileNetQuantizedWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetQuantizedWithoutMetadata", + .input_model_name = kMobileNetQuantizedWithoutMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 192, 192, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetFloatWithMetadataKeepAspectRatio", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = true, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {/*left*/ 0, + /*top*/ 0.161458, + /*right*/ 0, + /*bottom*/ 0.161458}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 23ebbe008..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -49,3 +49,28 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", ], ) + +mediapipe_proto_library( + name = "image_preprocessing_graph_options_proto", + srcs = 
["image_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + ":text_model_type_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/components/image_preprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto index d1685c319..bf4fc9067 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto @@ -15,14 +15,14 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; -message ImagePreprocessingOptions { +message ImagePreprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ImagePreprocessingOptions ext = 456882436; + optional ImagePreprocessingGraphOptions ext = 456882436; } // Options for the ImageToTensor calculator encapsulated by the diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. 
+ STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto similarity index 66% rename from mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 926e3d7fb..b610f7757 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -15,28 +15,19 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. - STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc similarity index 57% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 6aad8fdd5..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include @@ -25,14 +25,14 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -41,90 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::components::processors::proto:: + TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. 
-absl::StatusOr -GetPreprocessorType(const ModelResources& model_resources) { - const tflite::SubGraph& model_graph = - *(*model_resources.GetTfLiteModel()->subgraphs())[0]; - bool all_int32_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; - }); - bool all_string_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; - }); - if (!all_int32_tensors && !all_string_tensors) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "All input tensors should have type int32 or all should have type " - "string", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); - } - if (all_string_tensors) { - return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; - } - - // Otherwise, all tensors should have type int32 - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); - if (metadata_extractor->GetModelMetadata() == nullptr || - metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Text models with int32 input tensors require TFLite Model " - "Metadata but none was found", - MediaPipeTasksStatus::kMetadataNotFoundError); - } - - if (model_graph.inputs()->size() == kNumInputTensorsForBert) { - return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; - } - - if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { - return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; - } - - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::Substitute("Models with int32 input tensors should take exactly $0 " - "or $1 input tensors, but found $2", - kNumInputTensorsForBert, kNumInputTensorsForRegex, - model_graph.inputs()->size()), - MediaPipeTasksStatus::kInvalidNumInputTensorsError); -} - // Returns the maximum input sequence length accepted by the TFLite // model that owns `model graph` or returns an error if the model's input // tensors' shape is invalid for text preprocessing. 
This util assumes that the @@ -169,7 +114,7 @@ absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { } } // namespace -absl::Status ConfigureTextPreprocessingSubgraph( +absl::Status ConfigureTextPreprocessingGraph( const ModelResources& model_resources, TextPreprocessingGraphOptions& options) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { @@ -179,17 +124,16 @@ absl::Status ConfigureTextPreprocessingSubgraph( MediaPipeTasksStatus::kInvalidArgumentError); } - ASSIGN_OR_RETURN( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, - GetPreprocessorType(model_resources)); - options.set_preprocessor_type(preprocessor_type); - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + ASSIGN_OR_RETURN(TextModelType::ModelType model_type, + GetModelType(model_resources)); + options.set_model_type(model_type); + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::BERT_MODEL: + case TextModelType::REGEX_MODEL: { ASSIGN_OR_RETURN( int max_seq_len, GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); @@ -200,8 +144,7 @@ absl::Status ConfigureTextPreprocessingSubgraph( return absl::OkStatus(); } -// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text -// preprocessing. +// A TextPreprocessingGraph performs text preprocessing. // - Accepts a std::string input and outputs CPU tensors. // // Inputs: @@ -216,9 +159,9 @@ absl::Status ConfigureTextPreprocessingSubgraph( // Vector containing the preprocessed input tensors for the TFLite model. // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureTextPreprocessing()' function. See header file for more -// details. -class TextPreprocessingSubgraph : public mediapipe::Subgraph { +// using the 'ConfigureTextPreprocessingGraph()' function. See header file for +// more details. 
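// (Sketch mirroring how TextClassifierGraph wires this subgraph later in this
// change; 'inference' stands for the downstream InferenceCalculator node and
// is assumed here.)
//
//   auto& preprocessing = graph.AddNode(
//       "mediapipe.tasks.components.processors.TextPreprocessingGraph");
//   MP_RETURN_IF_ERROR(ConfigureTextPreprocessingGraph(
//       model_resources,
//       preprocessing.GetOptions<proto::TextPreprocessingGraphOptions>()));
//   text_in >> preprocessing.In(kTextTag);
//   // BERT_MODEL and REGEX_MODEL also need the metadata extractor side packet:
//   // metadata_extractor_side >> preprocessing.SideIn(kMetadataExtractorTag);
//   preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);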
+class TextPreprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -238,23 +181,22 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { absl::StatusOr>> BuildTextPreprocessing( const TextPreprocessingGraphOptions& options, Source text_in, SideSource metadata_extractor_in, Graph& graph) { - ASSIGN_OR_RETURN( - std::string preprocessor_name, - GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); + ASSIGN_OR_RETURN(std::string preprocessor_name, + GetCalculatorNameFromModelType(options.model_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); - switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + switch (options.model_type()) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { + case TextModelType::BERT_MODEL: { text_preprocessor.GetOptions() .set_bert_max_seq_len(options.max_seq_len()); metadata_extractor_in >> text_preprocessor.SideIn(kMetadataExtractorTag); break; } - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::REGEX_MODEL: { text_preprocessor.GetOptions() .set_max_seq_len(options.max_seq_len()); metadata_extractor_in >> @@ -267,8 +209,6 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::TextPreprocessingSubgraph); + ::mediapipe::tasks::components::processors::TextPreprocessingGraph); -} // namespace components -} // namespace tasks -} // namespace mediapipe +} // namespace mediapipe::tasks::components::processors diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h similarity index 67% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index b031a5550..43d57be29 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -13,26 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" -// Configures a TextPreprocessing subgraph using the provided `model_resources` +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a TextPreprocessingGraph using the provided `model_resources` // and TextPreprocessingGraphOptions. // - Accepts a std::string input and outputs CPU tensors. 
// // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( // model_resources, // &preprocessing.GetOptions())); // -// The resulting TextPreprocessing subgraph has the following I/O: +// The resulting TextPreprocessingGraph has the following I/O: // Inputs: // TEXT - std::string // The text to preprocess. @@ -43,16 +48,13 @@ limitations under the License. // Outputs: // TENSORS - std::vector // Vector containing the preprocessed input tensors for the TFLite model. -namespace mediapipe { -namespace tasks { -namespace components { - -absl::Status ConfigureTextPreprocessingSubgraph( - const tasks::core::ModelResources& model_resources, - tasks::components::proto::TextPreprocessingGraphOptions& options); +absl::Status ConfigureTextPreprocessingGraph( + const core::ModelResources& model_resources, + proto::TextPreprocessingGraphOptions& options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 8bb5b8415..2e0ea3ce6 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -14,12 +14,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) -cc_library( - name = "source_or_node_output", - hdrs = ["source_or_node_output.h"], - deps = ["//mediapipe/framework/api2:builder"], -) - cc_library( name = "cosine_similarity", srcs = ["cosine_similarity.cc"], diff --git a/mediapipe/tasks/cc/components/utils/source_or_node_output.h b/mediapipe/tasks/cc/components/utils/source_or_node_output.h deleted file mode 100644 index 55805d5a3..000000000 --- a/mediapipe/tasks/cc/components/utils/source_or_node_output.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ - -#include "mediapipe/framework/api2/builder.h" - -namespace mediapipe { -namespace tasks { - -// Helper class representing either a Source object or a GenericNode output. -// -// Source and MultiSource (the output of a GenericNode) are widely incompatible, -// but being able to represent either of these in temporary variables and -// connect them later on facilitates graph building. -template -class SourceOrNodeOutput { - public: - SourceOrNodeOutput() = delete; - // The caller is responsible for ensuring 'source' outlives this object. 
- explicit SourceOrNodeOutput(mediapipe::api2::builder::Source* source) - : source_(source) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, - std::string tag) - : node_(node), tag_(tag) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index) - : node_(node), index_(index) {} - - // Connects the source or node output to the provided destination. - template - void operator>>(const U& dest) { - if (source_ != nullptr) { - *source_ >> dest; - } else { - if (index_ < 0) { - node_->Out(tag_) >> dest; - } else { - node_->Out(index_) >> dest; - } - } - } - - private: - mediapipe::api2::builder::Source* source_ = nullptr; - mediapipe::api2::builder::GenericNode* node_ = nullptr; - std::string tag_ = ""; - int index_ = -1; -}; - -} // namespace tasks -} // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 291dd29fe..e5bc18306 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,6 +22,7 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], + visibility = ["//visibility:public"], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", @@ -39,16 +40,19 @@ cc_library( srcs = ["external_file_handler.cc"], hdrs = ["external_file_handler.h"], deps = [ - "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - ], + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + ] + select({ + "//mediapipe:windows": ["@bazel_tools//tools/cpp/runfiles"], + "//conditions:default": [], + }), ) cc_library( @@ -116,6 +120,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/util:resource_util", + "//mediapipe/util:resource_util_custom", "//mediapipe/util/tflite:error_reporter", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -307,7 +312,10 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = [ + "//mediapipe/calculators:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", diff --git a/mediapipe/tasks/cc/core/base_task_api.h b/mediapipe/tasks/cc/core/base_task_api.h index 1019c4fe9..92d41cc84 100644 --- a/mediapipe/tasks/cc/core/base_task_api.h +++ b/mediapipe/tasks/cc/core/base_task_api.h @@ -26,7 +26,7 @@ namespace mediapipe { namespace tasks { namespace core { -// The base calss of the user-facing mediapipe tasks api classes. +// The base class of the user-facing mediapipe tasks api classes. class BaseTaskApi { public: // Constructor. 
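A migration note on the SourceOrNodeOutput helper deleted above: the api2 builder's Source can be copied, and a GenericNode output can be retyped in place with Cast<>(), which is the pattern the embedding postprocessing graph adopts earlier in this change, so the helper is no longer needed. A minimal sketch of the replacement pattern (the calculator name and stream tag are taken from that hunk; the wrapper function itself is illustrative):

using ::mediapipe::Tensor;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

// Conditionally inserts a dequantization calculator and returns whichever
// stream should feed the downstream nodes, using plain builder Sources.
Source<std::vector<Tensor>> MaybeDequantize(Source<std::vector<Tensor>> tensors_in,
                                            bool quantized_outputs, Graph& graph) {
  Source<std::vector<Tensor>> out = tensors_in;  // pass-through by default
  if (quantized_outputs) {
    auto& dequantization_node = graph.AddNode("TensorsDequantizationCalculator");
    tensors_in >> dequantization_node.In("TENSORS");
    // GenericNode outputs are untyped; Cast<>() retypes the stream so it can
    // be stored in the same Source variable as the pass-through case.
    out = dequantization_node.Out("TENSORS").Cast<std::vector<Tensor>>();
  }
  return out;
}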
diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index 33dfeca0b..a95b8e744 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -37,12 +37,17 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#ifdef _WIN32 +#include "tools/cpp/runfiles/runfiles.h" +#endif // _WIN32 + namespace mediapipe { namespace tasks { namespace core { @@ -50,13 +55,21 @@ namespace { using ::absl::StatusCode; +#ifndef O_BINARY +#ifdef _O_BINARY +#define O_BINARY _O_BINARY +#else +#define O_BINARY 0 // If this isn't defined, the platform doesn't need it. +#endif // _O_BINARY +#endif // O_BINARY + // Gets the offset aligned to page size for mapping given files into memory by // file descriptor correctly, as according to mmap(2), the offset used in mmap // must be a multiple of sysconf(_SC_PAGE_SIZE). int64 GetPageSizeAlignedOffset(int64 offset) { #ifdef _WIN32 // mmap is not used on Windows - return -1; + return 0; #else int64 aligned_offset = offset; int64 page_size = sysconf(_SC_PAGE_SIZE); @@ -64,7 +77,7 @@ int64 GetPageSizeAlignedOffset(int64 offset) { aligned_offset = offset / page_size * page_size; } return aligned_offset; -#endif +#endif // _WIN32 } } // namespace @@ -83,13 +96,25 @@ ExternalFileHandler::CreateFromExternalFile( return handler; } -absl::Status ExternalFileHandler::MapExternalFile() { -// TODO: Add Windows support -#ifdef _WIN32 - return CreateStatusWithPayload(StatusCode::kFailedPrecondition, - "File loading is not yet supported on Windows", - MediaPipeTasksStatus::kFileReadError); +absl::StatusOr PathToResourceAsFile(std::string path) { +#ifndef _WIN32 + return path; #else + if (absl::StartsWith(path, "./")) { + path = "mediapipe" + path.substr(1); + } + + std::string error; + std::unique_ptr<::bazel::tools::cpp::runfiles::Runfiles> runfiles( + ::bazel::tools::cpp::runfiles::Runfiles::Create("", &error)); + if (!runfiles) { + return absl::InternalError("Unable to initialize runfiles: " + error); + } + return runfiles->Rlocation(path); +#endif // _WIN32 +} + +absl::Status ExternalFileHandler::MapExternalFile() { if (!external_file_.file_content().empty()) { return absl::OkStatus(); } else if (external_file_.has_file_pointer_meta()) { @@ -106,6 +131,7 @@ absl::Status ExternalFileHandler::MapExternalFile() { } return absl::OkStatus(); } + if (external_file_.file_name().empty() && !external_file_.has_file_descriptor_meta()) { return CreateStatusWithPayload( @@ -117,7 +143,9 @@ absl::Status ExternalFileHandler::MapExternalFile() { // Obtain file descriptor, offset and size. 
int fd = -1; if (!external_file_.file_name().empty()) { - owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY); + ASSIGN_OR_RETURN(std::string file_name, + PathToResourceAsFile(external_file_.file_name())); + owned_fd_ = open(file_name.c_str(), O_RDONLY | O_BINARY); if (owned_fd_ < 0) { const std::string error_message = absl::StrFormat( "Unable to open file at %s", external_file_.file_name()); @@ -148,6 +176,12 @@ absl::Status ExternalFileHandler::MapExternalFile() { } fd = owned_fd_; } else { +#ifdef _WIN32 + return CreateStatusWithPayload( + StatusCode::kFailedPrecondition, + "File descriptors are not supported on Windows.", + MediaPipeTasksStatus::kFileReadError); +#else fd = external_file_.file_descriptor_meta().fd(); if (fd < 0) { return CreateStatusWithPayload( @@ -157,6 +191,7 @@ absl::Status ExternalFileHandler::MapExternalFile() { } buffer_offset_ = external_file_.file_descriptor_meta().offset(); buffer_size_ = external_file_.file_descriptor_meta().length(); +#endif // _WIN32 } // Get actual file size. Always use 0 as offset to lseek(2) to get the actual // file size, as SEEK_END returns the size of the file *plus* offset. @@ -188,22 +223,37 @@ absl::Status ExternalFileHandler::MapExternalFile() { buffer_size_ + buffer_offset_, file_size), MediaPipeTasksStatus::kInvalidArgumentError); } + // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with // extra leading bytes and adjust buffer_size_ to account for the extra // leading bytes. buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_); buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_; + +#ifdef _WIN32 + buffer_ = malloc(file_size); + // Return the file pointer back to the beginning of the file + lseek(fd, 0L, SEEK_SET); + buffer_size_ = read(fd, buffer_, file_size); + if (buffer_size_ <= 0) { + free(buffer_); + buffer_ = nullptr; + } +#else // Map into memory. buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED, fd, buffer_aligned_offset_); if (buffer_ == MAP_FAILED) { + buffer_ = nullptr; + } +#endif // _WIN32 + if (!buffer_) { return CreateStatusWithPayload( StatusCode::kUnknown, absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno), MediaPipeTasksStatus::kFileMmapError); } return absl::OkStatus(); -#endif } absl::string_view ExternalFileHandler::GetFileContent() { @@ -222,11 +272,13 @@ absl::string_view ExternalFileHandler::GetFileContent() { } ExternalFileHandler::~ExternalFileHandler() { -#ifndef _WIN32 - if (buffer_ != MAP_FAILED) { + if (buffer_) { +#ifdef _WIN32 + free(buffer_); +#else munmap(buffer_, buffer_aligned_size_); +#endif // _WIN32 } -#endif if (owned_fd_ >= 0) { close(owned_fd_); } diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 618761f32..7819f6213 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/resource_util_custom.h" #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -99,11 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { - // If the model file name is a relative path, searches the file in a - // platform-specific location and returns the absolute path on success. - ASSIGN_OR_RETURN(std::string path_to_resource, - mediapipe::PathToResourceAsFile(model_file_->file_name())); - model_file_->set_file_name(path_to_resource); + if (HasCustomGlobalResourceProvider()) { + // If the model contents are provided via a custom ResourceProviderFn, the + // open() method may not work. Thus, loads the model content from the + // model file path in advance with the help of GetResourceContents. + MP_RETURN_IF_ERROR(GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); + } else { + // If the model file name is a relative path, searches the file in a + // platform-specific location and returns the absolute path on success. + ASSIGN_OR_RETURN(std::string path_to_resource, + PathToResourceAsFile(model_file_->file_name())); + model_file_->set_file_name(path_to_resource); + } } ASSIGN_OR_RETURN( model_file_handler_, diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 66434483b..0cb556ec2 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -186,7 +186,7 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( absl::StatusOr ModelTaskGraph::CreateModelAssetBundleResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix) { + std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 50dcc903b..3068b2c46 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -59,14 +59,16 @@ class ModelTaskGraph : public Subgraph { // creates a local model resources object that can only be used in the graph // construction stage. The returned model resources pointer will provide graph // authors with the access to the metadata extractor and the tflite model. + // If more than one model resources are created in a graph, the model + // resources graph service add the tag_suffix to support multiple resources. template absl::StatusOr CreateModelResources( - SubgraphContext* sc) { + SubgraphContext* sc, std::string tag_suffix = "") { auto external_file = std::make_unique(); external_file->Swap(sc->MutableOptions() ->mutable_base_options() ->mutable_model_asset()); - return CreateModelResources(sc, std::move(external_file)); + return CreateModelResources(sc, std::move(external_file), tag_suffix); } // If the model resources graph service is available, creates a model @@ -83,7 +85,7 @@ class ModelTaskGraph : public Subgraph { // resources. 
absl::StatusOr CreateModelResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix = ""); + std::string tag_suffix = ""); // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 52b0c0e4b..3c9c3fc0e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -16,35 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "text_classifier_graph", - srcs = ["text_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_resources_calculator", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Text Classifier +# https://developers.google.com/mediapipe/solutions/text/text_classifier cc_library( name = "text_classifier", srcs = ["text_classifier.cc"], hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_graph", "//mediapipe/framework:packet", @@ -65,6 +43,31 @@ cc_library( ], ) +cc_library( + name = "text_classifier_graph", + srcs = ["text_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator_cpu", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_calculator", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + cc_test( name = "text_classifier_test", srcs = ["text_classifier_test.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto 
b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 8f4d7eea6..41f87b519 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..3be92f309 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -25,8 +25,8 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). 
// // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,18 +109,18 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. - return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..71f7b1f2d 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License. 
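With the deprecated CLASSIFICATION_RESULT stream removed from the TextClassifierGraph above, an enclosing graph only needs to wire the single CLASSIFICATIONS output. A rough builder-API sketch of embedding the subgraph, assuming the registered type name follows the namespace as with the other task graphs in this change; the wrapper function is hypothetical and option population is elided:

#include <string>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

// Illustrative only: wrap TextClassifierGraph and expose its single output.
mediapipe::CalculatorGraphConfig BuildTextClassifierWrapper() {
  using ::mediapipe::api2::Input;
  using ::mediapipe::api2::Output;
  using ::mediapipe::api2::builder::Graph;
  using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

  Graph graph;
  auto& classifier = graph.AddNode(
      "mediapipe.tasks.text.text_classifier.TextClassifierGraph");
  // TextClassifierGraphOptions (model asset, classifier options) would be
  // filled in via classifier.GetOptions<...>() here.
  graph[Input<std::string>("TEXT")] >> classifier.In("TEXT");
  classifier.Out("CLASSIFICATIONS") >>
      graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  return graph.GetConfig();
}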
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -136,28 +135,46 @@ TEST_F(TextClassifierTest, TextClassifierWithBert) { options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN( - TextClassifierResult negative_result, - classifier->Classify("unflinchingly bleak and desperate")); + TextClassifierResult negative_expected; + TextClassifierResult positive_expected; + +#ifdef _WIN32 + negative_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/0, /*score=*/0.956124, /*category_name=*/"negative"}, + {/*index=*/1, /*score=*/0.043875, /*category_name=*/"positive"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + positive_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/1, /*score=*/0.999951, /*category_name=*/"positive"}, + {/*index=*/0, /*score=*/0.000048, /*category_name=*/"negative"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); +#else negative_expected.classifications.emplace_back(Classifications{ /*categories=*/{ {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"}, {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}}, /*head_index=*/0, /*head_name=*/"probability"}); - ExpectApproximatelyEqual(negative_result, negative_expected); - - MP_ASSERT_OK_AND_ASSIGN( - TextClassifierResult positive_result, - classifier->Classify("it's a charming and often affecting journey")); - TextClassifierResult positive_expected; positive_expected.classifications.emplace_back(Classifications{ /*categories=*/{ {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"}, {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}}, /*head_index=*/0, /*head_name=*/"probability"}); +#endif // _WIN32 + + MP_ASSERT_OK_AND_ASSIGN( + TextClassifierResult negative_result, + classifier->Classify("unflinchingly bleak and desperate")); + ExpectApproximatelyEqual(negative_result, negative_expected); + + MP_ASSERT_OK_AND_ASSIGN( + TextClassifierResult positive_result, + classifier->Classify("it's a charming and often affecting journey")); ExpectApproximatelyEqual(positive_result, positive_expected); MP_ASSERT_OK(classifier->Close()); @@ -217,8 +234,47 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + 
TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector categories; + +// Predicted scores are slightly different across platforms. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#elif defined _WIN32 + categories.push_back( + {/*index=*/1, /*score=*/0.976686, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.023313, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index e2e16c9c1..4c970159e 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Embedder +# https://developers.google.com/mediapipe/solutions/text/text_embedder cc_library( name = "text_embedder", srcs = ["text_embedder.cc"], hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_graph", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", @@ -45,17 +48,17 @@ cc_library( name = "text_embedder_graph", srcs = ["text_embedder_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index 6b8d41a57..fc8e02858 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -18,9 +18,13 @@ syntax = "proto2"; package mediapipe.tasks.text.text_embedder.proto; import "mediapipe/framework/calculator.proto"; 
+import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.text.textembedder.proto"; +option java_outer_classname = "TextEmbedderGraphOptionsProto"; + message TextEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional TextEmbedderGraphOptions ext = 477589892; diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 79eedb6b5..225ef07bd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. @@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding result. 
diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index fa3d8af91..533d829b9 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -75,7 +75,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) { text_embedder->Embed("it's a charming and often affecting journey")); ASSERT_EQ(result0.embeddings.size(), 1); ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512); +#ifdef _WIN32 + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon); +#else ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon); +#endif // _WIN32 MP_ASSERT_OK_AND_ASSIGN( auto result1, text_embedder->Embed("what a great and fantastic trip")); @@ -87,7 +91,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) { MP_ASSERT_OK_AND_ASSIGN( double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], result1.embeddings[0])); +#ifdef _WIN32 + EXPECT_NEAR(similarity, 0.971417, kSimilarityTolerancy); +#else EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy); +#endif // _WIN32 MP_ASSERT_OK(text_embedder->Close()); } @@ -139,5 +147,36 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + // TODO: These similarity should likely be lower +#ifdef _WIN32 + EXPECT_NEAR(similarity, 0.98152, kSimilarityTolerancy); +#else + EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy); +#endif // _WIN32 + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 7f1ea2848..92fac8eaa 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/framework:mediapipe_internal"]) +package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 710e8a984..092a7d450 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -43,3 +43,43 @@ cc_test( "@com_google_absl//absl/container:node_hash_map", ], ) + +cc_library( + name = "text_model_utils", + srcs = ["text_model_utils.cc"], + hdrs = ["text_model_utils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "text_model_utils_test", + srcs = ["text_model_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. +absl::StatusOr GetIntTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text models with int32 input tensors require TFLite Model " + "Metadata but none was found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + + if (num_input_tensors == kNumInputTensorsForBert) { + return TextModelType::BERT_MODEL; + } + + if (num_input_tensors == kNumInputTensorsForRegex) { + return TextModelType::REGEX_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with int32 input tensors should take exactly $0 " + "or $1 input tensors, but found $2", + kNumInputTensorsForBert, kNumInputTensorsForRegex, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} + +// Determines the ModelType for a model with string input tensors based +// on the number of input tensors. Returns an error if there is an invalid +// number of input tensors. 
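Callers are expected to branch on the TextModelType that the GetModelType entry point (defined just below) returns; a sketch of that dispatch, assuming the stripped return types are TextModelType::ModelType as the enum constants suggest. The helper function and its description strings are illustrative, not library code:

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

// Illustrative only: 3 int32 input tensors -> BERT, 1 int32 tensor -> regex,
// 1 string tensor -> string model, per the constants above.
absl::StatusOr<std::string> DescribePreprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources) {
  using ::mediapipe::tasks::components::processors::proto::TextModelType;
  ASSIGN_OR_RETURN(TextModelType::ModelType type,
                   mediapipe::tasks::text::utils::GetModelType(model_resources));
  switch (type) {
    case TextModelType::BERT_MODEL:
      return std::string("BERT-style tokenization (ids/mask/segment tensors)");
    case TextModelType::REGEX_MODEL:
      return std::string("regex tokenization into a single ids tensor");
    case TextModelType::STRING_MODEL:
      return std::string("raw string tensor passed straight to the model");
    default:
      return absl::InvalidArgumentError("Unsupported text model type");
  }
}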
+absl::StatusOr GetStringTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + if (num_input_tensors == kNumInputTensorsForStringPreprocessor) { + return TextModelType::STRING_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with string input tensors should take exactly " + "$0 tensors, but found $1", + kNumInputTensorsForStringPreprocessor, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} +} // namespace + +absl::StatusOr GetModelType( + const ModelResources& model_resources) { + const tflite::SubGraph& model_graph = + *(*model_resources.GetTfLiteModel()->subgraphs())[0]; + bool all_int32_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; + }); + bool all_string_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; + }); + if (!all_int32_tensors && !all_string_tensors) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "All input tensors should have type int32 or all should have type " + "string", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + if (all_string_tensors) { + return GetStringTensorModelType(model_resources, + model_graph.inputs()->size()); + } + + // Otherwise, all tensors should have type int32 + return GetIntTensorModelType(model_resources, model_graph.inputs()->size()); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h new file mode 100644 index 000000000..da8783d33 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +namespace mediapipe::tasks::text::utils { + +// Determines the ModelType for the model based on its metadata as well +// as its input tensors' type and count. Returns an error if there is no +// compatible model type. +absl::StatusOr +GetModelType(const core::ModelResources& model_resources); + +} // namespace mediapipe::tasks::text::utils + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc new file mode 100644 index 000000000..c02f8eca5 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::utils { + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::proto::ExternalFile; + +constexpr absl::string_view kTestModelResourcesTag = "test_model_resources"; + +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/text/"; +// Classification model with BERT preprocessing. +constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite"; +// Embedding model with BERT preprocessing. +constexpr absl::string_view kMobileBert = + "mobilebert_embedding_with_metadata.tflite"; +// Classification model with regex preprocessing. +constexpr absl::string_view kRegexClassifierPath = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +// Embedding model with regex preprocessing. +constexpr absl::string_view kRegexOneEmbeddingModel = + "regex_one_embedding_with_metadata.tflite"; +// Classification model that takes a string tensor and outputs a bool tensor. 
+constexpr absl::string_view kStringToBoolModelPath = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +absl::StatusOr GetModelTypeFromFile( + absl::string_view file_name) { + auto model_file = std::make_unique(); + model_file->set_file_name(GetFullPath(file_name)); + ASSIGN_OR_RETURN(auto model_resources, + ModelResources::Create(std::string(kTestModelResourcesTag), + std::move(model_file))); + return GetModelType(*model_resources); +} + +} // namespace + +class TextModelUtilsTest : public tflite_shims::testing::Test {}; + +TEST_F(TextModelUtilsTest, BertClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kBertClassifierPath)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, BertEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexClassifierPath)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexOneEmbeddingModel)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, StringInputModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kStringToBoolModelPath)); + ASSERT_EQ(model_type, TextModelType::STRING_MODEL); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index e8e197a1d..1f5ab5faf 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -19,11 +19,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( name = "image_processing_options", hdrs = ["image_processing_options.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/cc/components/containers:rect", ], diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c3c0a0261..a86b2cca8 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { if (roi.left >= roi.right || roi.top >= roi.bottom) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect with left < right and top < bottom.", + "Expected RectF with left < right and top < bottom.", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect values to be in [0,1].", + "Expected RectF values to be in [0,1].", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } normalized_rect.set_x_center((roi.left + roi.right) / 2.0); diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index 7e764c1fe..e2647be71 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ 
b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -28,14 +28,15 @@ namespace core { // Options for image processing. // // If both region-or-interest and rotation are specified, the crop around the -// region-of-interest is extracted first, the the specified rotation is applied +// region-of-interest is extracted first, then the specified rotation is applied // to the crop. struct ImageProcessingOptions { // The optional region-of-interest to crop from the image. If not specified, // the full image is used. // // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. - std::optional region_of_interest = std::nullopt; + std::optional region_of_interest = + std::nullopt; // The rotation to apply to the image (or cropped region-of-interest), in // degrees clockwise. diff --git a/mediapipe/tasks/cc/vision/face_detector/BUILD b/mediapipe/tasks/cc/vision/face_detector/BUILD new file mode 100644 index 000000000..09af34aa0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/BUILD @@ -0,0 +1,61 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = [ + # "//mediapipe/tasks:internal", + "//visibility:public", +]) + +licenses(["notice"]) + +cc_library( + name = "face_detector_graph", + srcs = ["face_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:clip_vector_size_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + 
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc new file mode 100644 index 000000000..6b60621a6 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc @@ -0,0 +1,208 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::Tensor; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::face_detector::proto:: + FaceDetectorGraphOptions; + +namespace { +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kDetectionsTag[] = "DETECTIONS"; + +void ConfigureSsdAnchorsCalculator( + mediapipe::SsdAnchorsCalculatorOptions* options) { + // TODO config SSD anchors parameters from metadata. 
+ options->set_num_layers(1); + options->set_min_scale(0.1484375); + options->set_max_scale(0.75); + options->set_input_size_height(192); + options->set_input_size_width(192); + options->set_anchor_offset_x(0.5); + options->set_anchor_offset_y(0.5); + options->add_strides(4); + options->add_aspect_ratios(1.0); + options->set_fixed_anchor_size(true); + options->set_interpolated_scale_aspect_ratio(0.0); +} + +void ConfigureTensorsToDetectionsCalculator( + const FaceDetectorGraphOptions& tasks_options, + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + // TODO use metadata to configure these fields. + options->set_num_classes(1); + options->set_num_boxes(2304); + options->set_num_coords(16); + options->set_box_coord_offset(0); + options->set_keypoint_coord_offset(4); + options->set_num_keypoints(6); + options->set_num_values_per_keypoint(2); + options->set_sigmoid_score(true); + options->set_score_clipping_thresh(100.0); + options->set_reverse_output_order(true); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); + options->set_x_scale(192.0); + options->set_y_scale(192.0); + options->set_w_scale(192.0); + options->set_h_scale(192.0); +} + +void ConfigureNonMaxSuppressionCalculator( + const FaceDetectorGraphOptions& tasks_options, + mediapipe::NonMaxSuppressionCalculatorOptions* options) { + options->set_min_suppression_threshold( + tasks_options.min_suppression_threshold()); + options->set_overlap_type( + mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION); + options->set_algorithm( + mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED); +} + +} // namespace + +class FaceDetectorGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN(const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN(auto face_detections, + BuildFaceDetectionSubgraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); + face_detections >> graph[Output>(kDetectionsTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr>> BuildFaceDetectionSubgraph( + const FaceDetectorGraphOptions& subgraph_options, + const core::ModelResources& model_resources, Source image_in, + Source norm_rect_in, Graph& graph) { + // Image preprocessing subgraph to convert image to tensor for the tflite + // model. + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( + model_resources, use_gpu, + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + image_in >> preprocessing.In("IMAGE"); + norm_rect_in >> preprocessing.In("NORM_RECT"); + auto preprocessed_tensors = preprocessing.Out("TENSORS"); + auto matrix = preprocessing.Out("MATRIX"); + + // Face detection model inferece. 
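The hard-coded values above are mutually consistent: a 192x192 input with a single stride-4 anchor layer gives a 48x48 grid, and with one anchor per cell (a single aspect ratio and an interpolated_scale_aspect_ratio of 0.0) that is exactly the 2304 boxes TensorsToDetectionsCalculator is told to decode. A small compile-time restatement of that arithmetic, illustrative only and assuming one anchor per cell:

// Illustrative check, not part of the graph: the anchor count implied by the
// SsdAnchorsCalculator settings above must match num_boxes.
constexpr int kInputSize = 192;      // set_input_size_{width,height}(192)
constexpr int kStride = 4;           // add_strides(4)
constexpr int kAnchorsPerCell = 1;   // one aspect ratio, no interpolated scale
constexpr int kCellsPerSide = kInputSize / kStride;  // 48
constexpr int kExpectedNumBoxes =
    kCellsPerSide * kCellsPerSide * kAnchorsPerCell;  // 2304
static_assert(kExpectedNumBoxes == 2304,
              "num_boxes passed to TensorsToDetectionsCalculator should match "
              "the number of generated SSD anchors");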
+ auto& inference = AddInference( + model_resources, subgraph_options.base_options().acceleration(), graph); + preprocessed_tensors >> inference.In("TENSORS"); + auto model_output_tensors = + inference.Out("TENSORS").Cast>(); + + // Generates a single side packet containing a vector of SSD anchors. + auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); + ConfigureSsdAnchorsCalculator( + &ssd_anchor.GetOptions()); + auto anchors = ssd_anchor.SideOut(""); + + // Converts output tensors to Detections. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + ConfigureTensorsToDetectionsCalculator( + subgraph_options, + &tensors_to_detections + .GetOptions()); + model_output_tensors >> tensors_to_detections.In("TENSORS"); + anchors >> tensors_to_detections.SideIn("ANCHORS"); + auto detections = tensors_to_detections.Out("DETECTIONS"); + + // Non maximum suppression removes redundant face detections. + auto& non_maximum_suppression = + graph.AddNode("NonMaxSuppressionCalculator"); + ConfigureNonMaxSuppressionCalculator( + subgraph_options, + &non_maximum_suppression + .GetOptions()); + detections >> non_maximum_suppression.In(""); + auto nms_detections = non_maximum_suppression.Out(""); + + // Projects detections back into the input image coordinates system. + auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); + nms_detections >> detection_projection.In("DETECTIONS"); + matrix >> detection_projection.In("PROJECTION_MATRIX"); + auto face_detections = + detection_projection[Output>("DETECTIONS")]; + + return {face_detections}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::face_detector::FaceDetectorGraph); + +} // namespace face_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc new file mode 100644 index 000000000..fc1f49f13 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_detector { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::mediapipe::tasks::vision::face_detector::proto:: + FaceDetectorGraphOptions; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite"; +constexpr char kFullRangeSparseBlazeFaceModel[] = + "face_detection_full_range_sparse.tflite"; +constexpr char kPortraitImage[] = "portrait.jpg"; +constexpr char kPortraitExpectedDetection[] = + "portrait_expected_detection.pbtxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kDetectionsName[] = "detections"; + +constexpr float kFaceDetectionMaxDiff = 0.01; + +// Helper function to create a TaskRunner. 
+absl::StatusOr> CreateTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& face_detector_graph = + graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->set_min_detection_confidence(0.6); + options->set_min_suppression_threshold(0.3); + face_detector_graph.GetOptions().Swap( + options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + face_detector_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + face_detector_graph.In(kNormRectTag); + + face_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >> + graph[Output>(kDetectionsTag)]; + + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); +} + +Detection GetExpectedFaceDetectionResult(absl::string_view file_name) { + Detection detection; + CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) + << "Expected face detection result does not exist."; + return detection; +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of face landmark detection model. + std::string face_detection_model_name; + // The filename of test image. + std::string test_image_name; + // Expected face detection results. + std::vector expected_result; +}; + +class FaceDetectorGraphTest : public testing::TestWithParam {}; + +TEST_P(FaceDetectorGraphTest, Succeed) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateTaskRunner(GetParam().face_detection_model_name)); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + MP_ASSERT_OK(output_packets); + const std::vector& face_detections = + (*output_packets)[kDetectionsName].Get>(); + EXPECT_THAT(face_detections, Pointwise(Approximately(Partially(EqualsProto()), + kFaceDetectionMaxDiff), + GetParam().expected_result)); +} + +INSTANTIATE_TEST_SUITE_P( + FaceDetectorGraphTest, FaceDetectorGraphTest, + Values(TestParams{.test_name = "FullRange", + .face_detection_model_name = kFullRangeBlazeFaceModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedFaceDetectionResult( + kPortraitExpectedDetection)}}, + TestParams{ + .test_name = "FullRangeSparse", + .face_detection_model_name = kFullRangeSparseBlazeFaceModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedFaceDetectionResult( + kPortraitExpectedDetection)}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace face_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD similarity index 70% rename from mediapipe/tasks/cc/components/proto/BUILD rename to mediapipe/tasks/cc/vision/face_detector/proto/BUILD index 4534a1652..ca9a6f8c4 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD 
@@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,20 +14,18 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) licenses(["notice"]) mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) - -mediapipe_proto_library( - name = "text_preprocessing_graph_options_proto", - srcs = ["text_preprocessing_graph_options.proto"], + name = "face_detector_graph_options_proto", + srcs = ["face_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto new file mode 100644 index 000000000..a58338288 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto @@ -0,0 +1,42 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.face_detector.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.facedetector.proto"; +option java_outer_classname = "FaceDetectorGraphOptionsProto"; + +message FaceDetectorGraphOptions { + extend mediapipe.CalculatorOptions { + optional FaceDetectorGraphOptions ext = 502141897; + } + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Minimum confidence value ([0.0, 1.0]) for the confidence score to be + // considered a successful face detection. + optional float min_detection_confidence = 2 [default = 0.5]; + + // IoU threshold ([0.0, 1.0]) above which non-maximum suppression considers + // detections to be duplicates.
+ optional float min_suppression_threshold = 3 [default = 0.5]; +} diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 75289b1e8..7ffae6ff2 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -18,6 +18,51 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Gesture Recognizer +# https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":gesture_recognizer_graph", + ":gesture_recognizer_result", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_library( name = "handedness_util", srcs = ["handedness_util.cc"], @@ -59,10 +104,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_cache", @@ -98,6 +140,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", 
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -127,51 +170,9 @@ cc_library( cc_library( name = "gesture_recognizer_result", hdrs = ["gesture_recognizer_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) - -cc_library( - name = "gesture_recognizer", - srcs = ["gesture_recognizer.cc"], - hdrs = ["gesture_recognizer.h"], - deps = [ - ":gesture_recognizer_graph", - ":gesture_recognizer_result", - ":hand_gesture_recognizer_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 277bb170a..088f97c29 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -35,6 +35,8 @@ limitations under the License. 
namespace mediapipe { namespace api2 { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index fe6f1162b..a1a44c8d1 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -33,6 +33,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 8d555b12c..91a5ec213 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -58,6 +57,8 @@ namespace { using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: gesture_recognizer::proto::GestureRecognizerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandGestureSubgraphTypeName[] = "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; @@ -150,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto custom_gestures_classifier_options_proto = std::make_unique( components::processors::ConvertClassifierOptionsToProto( - &(options->canned_gestures_classifier_options))); + &(options->custom_gestures_classifier_options))); hand_gesture_recognizer_graph_options ->mutable_custom_gesture_classifier_graph_options() ->mutable_classifier_options() - ->Swap(canned_gestures_classifier_options_proto.get()); + ->Swap(custom_gestures_classifier_options_proto.get()); return options_proto; } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 47d95100b..b6f6c88da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -46,6 +47,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -67,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; constexpr char kHandGestureRecognizerBundleAssetName[] = "hand_gesture_recognizer.task"; @@ -76,6 +81,9 @@ struct GestureRecognizerOutputs { Source> handedness; Source> hand_landmarks; Source> hand_world_landmarks; + Source> hand_rects_next_frame; + Source> palm_rects; + Source> palm_detections; Source image; }; @@ -134,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // Inputs: // IMAGE - Image // Image to perform hand gesture recognition on. -// NORM_RECT - NormalizedRect +// NORM_RECT - NormalizedRect @Optional // Describes image rotation and region of image to perform landmarks -// detection on. +// detection on. If not provided, whole image is used for gesture +// recognition. // // Outputs: // HAND_GESTURES - std::vector @@ -207,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, - BuildGestureRecognizerGraph( - *sc->MutableOptions(), - graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_gesture_recognition_output, + BuildGestureRecognizerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_gesture_recognition_output.gesture >> graph[Output>(kHandGesturesTag)]; hand_gesture_recognition_output.handedness >> @@ -221,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { hand_gesture_recognition_output.hand_world_landmarks >> graph[Output>(kWorldLandmarksTag)]; hand_gesture_recognition_output.image >> graph[Output(kImageTag)]; + hand_gesture_recognition_output.hand_rects_next_frame >> + graph[Output>(kRectNextFrameTag)]; + hand_gesture_recognition_output.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_gesture_recognition_output.palm_detections >> + graph[Output>(kPalmDetectionsTag)]; return graph.GetConfig(); } @@ -278,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { /*handedness=*/handedness, /*hand_landmarks=*/hand_landmarks, /*hand_world_landmarks=*/hand_world_landmarks, - /*image=*/hand_landmarker_graph[Output(kImageTag)]}; + /*hand_rects_next_frame =*/ + hand_landmarker_graph[Output>( + kRectNextFrameTag)], + /*palm_rects =*/ + hand_landmarker_graph[Output>( + kPalmRectsTag)], + /*palm_detections =*/ + hand_landmarker_graph[Output>( + kPalmDetectionsTag)], + /*image=*/hand_landmarker_graph[Output(kImageTag)], + }; } }; 
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 7b6a8c79d..4db57e85b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -29,8 +29,6 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" -#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" @@ -54,6 +52,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index dcefa075f..edbabc018 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index bff4e0a9c..df909a6db 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index 57d8a3746..fef22c07c 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; import 
"mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index 7df2fed37..ae85509da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 71cef6270..55162d09b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -46,7 +46,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 06bb2e549..d7163e331 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -50,6 +50,7 @@ namespace hand_detector { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -149,9 +150,9 @@ void ConfigureRectTransformationCalculator( // Inputs: // IMAGE - Image // Image to perform detection on. -// NORM_RECT - NormalizedRect -// Describes image rotation and region of image to perform detection -// on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection on. If +// not provided, whole image is used for hand detection. 
// // Outputs: // PALM_DETECTIONS - std::vector @@ -196,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_detection_outs, - BuildHandDetectionSubgraph( - sc->Options(), - *model_resources, graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_detection_outs, + BuildHandDetectionSubgraph( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_detection_outs.palm_detections >> graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> @@ -226,21 +228,23 @@ class HandDetectorGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); auto& image_to_tensor_options = *preprocessing - .GetOptions() + .GetOptions() .mutable_image_to_tensor_options(); image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In("IMAGE"); norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index cbbc0e193..f4e5f8c7d 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -53,6 +53,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index a009f2365..bede70da5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5c5073fc2..2552e7a10 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ 
b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -18,6 +18,48 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Hand Landmarker +# https://developers.google.com/mediapipe/solutions/vision/hand_landmarker +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_graph", + ":hand_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "hand_landmark", + hdrs = ["hand_landmark.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], @@ -52,7 +94,7 @@ cc_library( "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", @@ -112,44 +154,14 @@ cc_library( cc_library( name = "hand_landmarker_result", + srcs = ["hand_landmarker_result.cc"], hdrs = ["hand_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - ], -) - -cc_library( - name = "hand_landmarker", - srcs = ["hand_landmarker.cc"], - hdrs = ["hand_landmarker.h"], - deps = [ - ":hand_landmarker_graph", - ":hand_landmarker_result", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - 
"//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers:landmark", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index b6df80588..dffdbdd38 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -27,6 +27,8 @@ limitations under the License. namespace mediapipe::api2 { +using ::mediapipe::NormalizedRect; + // HandAssociationCalculator accepts multiple inputs of vectors of // NormalizedRect. The output is a vector of NormalizedRect that contains // rects from the input vectors that don't overlap with each other. When two diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index cb3130854..138164209 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -26,6 +26,8 @@ limitations under the License. namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + class HandAssociationCalculatorTest : public testing::Test { protected: HandAssociationCalculatorTest() { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 564184c64..d875de98f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -41,10 +41,11 @@ limitations under the License. 
namespace mediapipe::api2 { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::DuplicatesFinder; @@ -126,7 +127,7 @@ absl::StatusOr HandBaselineDistance( return distance; } -Rect CalculateBound(const NormalizedLandmarkList& list) { +RectF CalculateBound(const NormalizedLandmarkList& list) { constexpr float kMinInitialValue = std::numeric_limits::max(); constexpr float kMaxInitialValue = std::numeric_limits::lowest(); @@ -144,10 +145,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return Rect{/*left=*/bounding_box_left, - /*top=*/bounding_box_top, - /*right=*/bounding_box_right, - /*bottom=*/bounding_box_bottom}; + return RectF{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect @@ -172,7 +173,7 @@ class HandDuplicatesFinder : public DuplicatesFinder { const int num = multi_landmarks.size(); std::vector baseline_distances; baseline_distances.reserve(num); - std::vector bounds; + std::vector bounds; bounds.reserve(num); for (const NormalizedLandmarkList& list : multi_landmarks) { ASSIGN_OR_RETURN(const float baseline_distance, diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h new file mode 100644 index 000000000..c8dbc9254 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ + +namespace mediapipe::tasks::vision::hand_landmarker { + +// The 21 hand landmarks. 
+enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +}; + +} // namespace mediapipe::tasks::vision::hand_landmarker + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3a9ed5bc2..ab66fe136 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -47,6 +46,8 @@ namespace { using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: hand_landmarker::proto::HandLandmarkerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandLandmarkerGraphTypeName[] = "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; @@ -154,9 +155,13 @@ absl::StatusOr> HandLandmarker::Create( Packet hand_world_landmarks_packet = status_or_packets.value()[kHandWorldLandmarksStreamName]; result_callback( - {{handedness_packet.Get>(), - hand_landmarks_packet.Get>(), - hand_world_landmarks_packet.Get>()}}, + ConvertToHandLandmarkerResult( + /* handedness= */ handedness_packet + .Get>(), + /* hand_landmarks= */ + hand_landmarks_packet.Get>(), + /* hand_world_landmarks= */ + hand_world_landmarks_packet.Get>()), image_packet.Get(), hand_landmarks_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -192,15 +197,21 @@ absl::StatusOr HandLandmarker::Detect( if (output_packets[kHandLandmarksStreamName].IsEmpty()) { return {HandLandmarkerResult()}; } - return {{/* handedness= */ - {output_packets[kHandednessStreamName] - .Get>()}, - /* hand_landmarks= */ - {output_packets[kHandLandmarksStreamName] - .Get>()}, - /* hand_world_landmarks */ - {output_packets[kHandWorldLandmarksStreamName] - .Get>()}}}; + return ConvertToHandLandmarkerResult(/* handedness= */ + output_packets[kHandednessStreamName] + .Get>(), + /* hand_landmarks= */ + output_packets[kHandLandmarksStreamName] + .Get>(), + /* hand_world_landmarks */ + output_packets + [kHandWorldLandmarksStreamName] + .Get>()); } absl::StatusOr HandLandmarker::DetectForVideo( @@ -227,17 +238,21 @@ absl::StatusOr HandLandmarker::DetectForVideo( if (output_packets[kHandLandmarksStreamName].IsEmpty()) { return {HandLandmarkerResult()}; } - return { - {/* handedness= */ - {output_packets[kHandednessStreamName] - .Get>()}, - /* hand_landmarks= */ - {output_packets[kHandLandmarksStreamName] - .Get>()}, - /* hand_world_landmarks */ - {output_packets[kHandWorldLandmarksStreamName] - .Get>()}}, - }; + return ConvertToHandLandmarkerResult(/* handedness= */ + output_packets[kHandednessStreamName] + .Get>(), + /* hand_landmarks= */ + 
output_packets[kHandLandmarksStreamName] + .Get>(), + /* hand_world_landmarks */ + output_packets + [kHandWorldLandmarksStreamName] + .Get>()); } absl::Status HandLandmarker::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 2c4133eb1..74d288ac1 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -49,6 +49,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -135,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // Inputs: // IMAGE - Image // Image to perform hand landmarks detection on. -// NORM_RECT - NormalizedRect +// NORM_RECT - NormalizedRect @Optional // Describes image rotation and region of image to perform landmarks -// detection on. +// detection on. If not provided, whole image is used for hand landmarks +// detection. // // Outputs: // LANDMARKS: - std::vector @@ -217,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN(auto hand_landmarker_outputs, - BuildHandLandmarkerGraph( - sc->Options(), - graph[Input(kImageTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_landmarker_outputs, + BuildHandLandmarkerGraph( + sc->Options(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); hand_landmarker_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; hand_landmarker_outputs.world_landmark_lists >> diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index f275486f5..c28df2c05 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -54,6 +54,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc new file mode 100644 index 000000000..9d2ae2be8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc @@ -0,0 +1,56 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" + +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +HandLandmarkerResult ConvertToHandLandmarkerResult( + const std::vector& handedness_proto, + const std::vector& hand_landmarks_proto, + const std::vector& hand_world_landmarks_proto) { + HandLandmarkerResult result; + result.handedness.resize(handedness_proto.size()); + result.hand_landmarks.resize(hand_landmarks_proto.size()); + result.hand_world_landmarks.resize(hand_world_landmarks_proto.size()); + std::transform(handedness_proto.begin(), handedness_proto.end(), + result.handedness.begin(), + [](const mediapipe::ClassificationList& classification_list) { + return components::containers::ConvertToClassifications( + classification_list); + }); + std::transform(hand_landmarks_proto.begin(), hand_landmarks_proto.end(), + result.hand_landmarks.begin(), + components::containers::ConvertToNormalizedLandmarks); + std::transform(hand_world_landmarks_proto.begin(), + hand_world_landmarks_proto.end(), + result.hand_world_landmarks.begin(), + components::containers::ConvertToLandmarks); + return result; +} + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h index 5e51c244e..1bca8e66a 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" namespace mediapipe { namespace tasks { @@ -28,13 +30,18 @@ namespace hand_landmarker { // element represents a single hand detected in the image. struct HandLandmarkerResult { // Classification of handedness. - std::vector handedness; + std::vector handedness; // Detected hand landmarks in normalized image coordinates. - std::vector hand_landmarks; + std::vector hand_landmarks; // Detected hand landmarks in world coordinates. 
- std::vector hand_world_landmarks; + std::vector hand_world_landmarks; }; +HandLandmarkerResult ConvertToHandLandmarkerResult( + const std::vector& handedness_proto, + const std::vector& hand_landmarks_proto, + const std::vector& hand_world_landmarks_proto); + } // namespace hand_landmarker } // namespace vision } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc new file mode 100644 index 000000000..109749b01 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" + +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +TEST(ConvertFromProto, Succeeds) { + mediapipe::ClassificationList classification_list_proto; + mediapipe::Classification& classification_proto = + *classification_list_proto.add_classification(); + classification_proto.set_index(1); + classification_proto.set_score(0.5); + classification_proto.set_label("Left"); + classification_proto.set_display_name("Left_Hand"); + + mediapipe::NormalizedLandmarkList normalized_landmark_list_proto; + mediapipe::NormalizedLandmark& normalized_landmark_proto = + *normalized_landmark_list_proto.add_landmark(); + normalized_landmark_proto.set_x(0.1); + normalized_landmark_proto.set_y(0.2); + normalized_landmark_proto.set_z(0.3); + + mediapipe::LandmarkList landmark_list_proto; + mediapipe::Landmark& landmark_proto = *landmark_list_proto.add_landmark(); + landmark_proto.set_x(3.1); + landmark_proto.set_y(5.2); + landmark_proto.set_z(4.3); + + std::vector classification_lists = { + classification_list_proto}; + std::vector normalized_landmarks_lists = { + normalized_landmark_list_proto}; + std::vector landmarks_lists = {landmark_list_proto}; + + HandLandmarkerResult hand_landmarker_result = ConvertToHandLandmarkerResult( + classification_lists, normalized_landmarks_lists, landmarks_lists); + + EXPECT_EQ(hand_landmarker_result.handedness.size(), 1); + EXPECT_EQ(hand_landmarker_result.handedness[0].categories.size(), 1); + EXPECT_THAT( + hand_landmarker_result.handedness[0].categories[0], + testing::FieldsAre(1, testing::FloatEq(0.5), "Left", "Left_Hand")); + + EXPECT_EQ(hand_landmarker_result.hand_landmarks.size(), 1); + EXPECT_EQ(hand_landmarker_result.hand_landmarks[0].landmarks.size(), 1); + 
EXPECT_THAT(hand_landmarker_result.hand_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2), + testing::FloatEq(0.3), std::nullopt, + std::nullopt, std::nullopt)); + + EXPECT_EQ(hand_landmarker_result.hand_world_landmarks.size(), 1); + EXPECT_EQ(hand_landmarker_result.hand_world_landmarks[0].landmarks.size(), 1); + EXPECT_THAT(hand_landmarker_result.hand_world_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2), + testing::FloatEq(4.3), std::nullopt, + std::nullopt, std::nullopt)); +} + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index fa49a4c1f..b21f1bee9 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -32,6 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" @@ -50,18 +52,16 @@ namespace { using ::file::Defaults; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::ConvertToClassifications; +using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; -using ::testing::EqualsProto; using ::testing::HasSubstr; using ::testing::Optional; -using ::testing::Pointwise; using ::testing::TestParamInfo; using ::testing::TestWithParam; using ::testing::Values; -using ::testing::proto::Approximately; -using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task"; @@ -74,7 +74,6 @@ constexpr char kPointingUpImage[] = "pointing_up.jpg"; constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg"; constexpr char kNoHandsImage[] = "cats_and_dogs.jpg"; -constexpr float kLandmarksFractionDiff = 0.03; // percentage constexpr float kLandmarksAbsMargin = 0.03; constexpr float kHandednessMargin = 0.05; @@ -101,13 +100,47 @@ HandLandmarkerResult GetExpectedHandLandmarkerResult( const auto landmarks_detection_result = GetLandmarksDetectionResult(file_name); expected_results.hand_landmarks.push_back( - landmarks_detection_result.landmarks()); + ConvertToNormalizedLandmarks(landmarks_detection_result.landmarks())); expected_results.handedness.push_back( - landmarks_detection_result.classifications()); + ConvertToClassifications(landmarks_detection_result.classifications())); } return expected_results; } +MATCHER_P2(HandednessMatches, expected_handedness, tolerance, "") { + for (int i = 0; i < arg.size(); i++) { + for (int j = 0; j < arg[i].categories.size(); j++) { + if (arg[i].categories[j].index != + 
expected_handedness[i].categories[j].index) { + return false; + } + if (std::abs(arg[i].categories[j].score - + expected_handedness[i].categories[j].score) > tolerance) { + return false; + } + if (arg[i].categories[j].category_name != + expected_handedness[i].categories[j].category_name) { + return false; + } + } + } + return true; +} + +MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") { + for (int i = 0; i < arg.size(); i++) { + for (int j = 0; j < arg[i].landmarks.size(); j++) { + if (std::abs(arg[i].landmarks[j].x - + expected_landmarks[i].landmarks[j].x) > toleration || + std::abs(arg[i].landmarks[j].y - + expected_landmarks[i].landmarks[j].y) > toleration) { + return false; + } + } + } + return true; +} + void ExpectHandLandmarkerResultsCorrect( const HandLandmarkerResult& actual_results, const HandLandmarkerResult& expected_results) { @@ -119,16 +152,15 @@ void ExpectHandLandmarkerResultsCorrect( ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size()); ASSERT_EQ(actual_handedness.size(), expected_handedness.size()); + if (actual_landmarks.empty()) { + return; + } + ASSERT_GE(actual_landmarks.size(), 1); - EXPECT_THAT( - actual_handedness, - Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin), - expected_handedness)); + EXPECT_THAT(actual_handedness, + HandednessMatches(expected_handedness, kHandednessMargin)); EXPECT_THAT(actual_landmarks, - Pointwise(Approximately(Partially(EqualsProto()), - /*margin=*/kLandmarksAbsMargin, - /*fraction=*/kLandmarksFractionDiff), - expected_landmarks)); + LandmarksMatches(expected_landmarks, kLandmarksAbsMargin)); } } // namespace @@ -188,7 +220,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->running_mode = core::RunningMode::IMAGE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, HandLandmarker::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = hand_landmarker->Detect(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 1f127deb8..914bc30fc 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -53,6 +53,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -242,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildSingleHandLandmarksDetectorGraph( - sc->Options(), - *model_resources, graph[Input(kImageTag)], - graph[Input(kHandRectTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_landmark_detection_outs, + BuildSingleHandLandmarksDetectorGraph( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> graph[Output(kLandmarksTag)]; hand_landmark_detection_outs.world_hand_landmarks >> @@ -281,14 +283,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index d1e928ce7..f28907d2f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -50,6 +50,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 51e4e129a..d0edf99c0 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import 
"mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 195f6e5cc..a2d520963 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index b59d8d682..514e601ef 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_classifier_graph", - srcs = ["image_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Classifier +# https://developers.google.com/mediapipe/solutions/vision/image_classifier cc_library( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_graph", "//mediapipe/framework:packet", @@ -69,4 +49,27 @@ cc_library( ], ) +cc_library( + name = "image_classifier_graph", + srcs = ["image_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 60f8f7ed4..763e0a320 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -58,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..0adcf842d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -23,10 +23,10 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -38,6 +38,7 @@ namespace image_classifier { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; @@ -47,7 +48,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +56,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +76,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. 
-// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example: // node { @@ -117,8 +113,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -142,14 +136,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); @@ -174,8 +169,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. return ImageClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], /*classifications=*/ postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 1144e9032..7aa2a148c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -52,7 +52,7 @@ namespace { using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::Category; using ::mediapipe::tasks::components::containers::Classifications; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( @@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the chair, with 90° anti-clockwise rotation. 
- Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, + /*bottom=*/0.3049}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; @@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { ImageClassifier::Create(std::move(options))); // Invalid: left > right. - Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect values to be in [0,1]")); + HasSubstr("Expected RectF values to be in [0,1]")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { @@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. 
- Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 76315e230..24b126a35 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index ea7f40261..d729eaf1a 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_embedder_graph", - srcs = ["image_embedder_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Embedder +# https://developers.google.com/mediapipe/solutions/vision/image_embedder cc_library( name = "image_embedder", srcs = ["image_embedder.cc"], hdrs = ["image_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_graph", "//mediapipe/framework/api2:builder", @@ -67,4 +47,27 @@ cc_library( ], ) +cc_library( + name = "image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", 
+ "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index e3198090f..494b075a7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -54,6 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 11e25144c..95c4ff379 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -34,6 +34,7 @@ namespace image_embedder { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; @@ -130,14 +131,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); @@ -151,10 +153,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // inference results. 
auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding results. diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 6098a9a70..dd602bef5 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -41,7 +41,7 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". - Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. @@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger_rotated.jpg"))); // Region-of-interest corresponding to burger_crop.jpg. 
- Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 4adba5ab7..24ee866f2 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,9 +18,13 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto"; +option java_outer_classname = "ImageEmbedderGraphOptionsProto"; + message ImageEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional ImageEmbedderGraphOptions ext = 476348187; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 4c43a07f5..4c9c6e69c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,16 +16,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Image Segmenter +# https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], + visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -33,6 +35,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -53,17 +56,17 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", 
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD similarity index 94% rename from mediapipe/tasks/cc/components/calculators/tensor/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index 6e4322a8f..dcd7fb407 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -25,7 +25,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", ], ) @@ -45,7 +45,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc similarity index 95% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 40585848f..668de0057 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO consolidate TensorsToSegmentationCalculator. #include #include #include @@ -35,14 +34,14 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/util/label_map.pb.h" +// TODO: consolidate TensorToSegmentationCalculator. 
namespace mediapipe { namespace tasks { - namespace { using ::mediapipe::Image; @@ -51,9 +50,9 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; void StableSoftmax(absl::Span values, absl::Span activated_values) { @@ -90,7 +89,7 @@ void Sigmoid(absl::Span values, // the size to resize masks to. // // Output: -// Segmentation: Segmenation proto. +// Segmentation: Segmentation proto. // // Options: // See tensors_to_segmentation_calculator.proto @@ -132,8 +131,7 @@ class TensorsToSegmentationCalculator : public Node { absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { - options_ = - cc->Options(); + options_ = cc->Options(); RET_CHECK_NE(options_.segmenter_options().output_type(), SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto similarity index 82% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index c26cf910a..b0fdfdd32 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -15,10 +15,11 @@ limitations under the License. syntax = "proto2"; +// TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; message TensorsToSegmentationCalculatorOptions { @@ -26,7 +27,8 @@ message TensorsToSegmentationCalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; } - optional components.proto.SegmenterOptions segmenter_options = 1; + optional mediapipe.tasks.vision.image_segmenter.proto.SegmenterOptions + segmenter_options = 1; // Identifying information for each classification label. map label_items = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc similarity index 99% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 55e46d72b..54fb9b816 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -33,10 +33,9 @@ limitations under the License. 
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" namespace mediapipe { -namespace api2 { namespace { @@ -374,5 +373,4 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { expected_index, buffer_indices))); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 6dce1b4ea..7130c72e2 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -18,12 +18,12 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" namespace mediapipe { namespace tasks { @@ -44,7 +44,8 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; -using ::mediapipe::tasks::components::proto::SegmenterOptions; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 43bf5b7e6..511d3b9c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -98,7 +98,7 @@ struct ImageSegmenterOptions { // - list of segmented masks. // - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. +// `channels`. // - batch is always 1 // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 44742e043..923cf2937 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -49,15 +49,16 @@ namespace image_segmenter { namespace { using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; @@ -243,14 +244,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 752a116dd..c8c6e9036 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -28,11 +28,11 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -47,7 +47,7 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -257,10 +257,12 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } -TEST_F(ImageModeTest, SucceedsWithRotation) { +// TODO: fix this unit test after image segmenter handled post +// processing correctly with rotated image. +TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { MP_ASSERT_OK_AND_ASSIGN( - Image image, DecodeImageFromFile( - JoinPath("./", kTestDataDirectory, "cat_rotated.jpg"))); + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg"))); auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); @@ -271,7 +273,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { ImageSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + segmenter->Segment(image, image_processing_options)); EXPECT_EQ(confidence_masks.size(), 21); cv::Mat expected_mask = @@ -299,7 +302,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = segmenter->Segment(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 3b14060f1..9523dd679 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -18,13 +18,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + mediapipe_proto_library( name = "image_segmenter_graph_options_proto", srcs = ["image_segmenter_graph_options.proto"], deps = [ + ":segmenter_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", 
"//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 166e2e8e0..5c7d2ec71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,8 +18,9 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "ImageSegmenterGraphOptionsProto"; @@ -37,5 +38,5 @@ message ImageSegmenterGraphOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output options. - optional components.proto.SegmenterOptions segmenter_options = 3; + optional SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto similarity index 92% rename from mediapipe/tasks/cc/components/proto/segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index ca9986707..be2b8a589 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.vision.image_segmenter.proto; -option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "SegmenterOptionsProto"; // Shared options used by image segmentation tasks. 
diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 8220d8b7f..0238449c7 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -16,50 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "object_detector_graph", - srcs = ["object_detector_graph.cc"], - deps = [ - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", - "//mediapipe/calculators/util:detection_projection_calculator", - "//mediapipe/calculators/util:detection_transformation_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Object Detector +# https://developers.google.com/mediapipe/solutions/vision/object_detector cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], + visibility = ["//visibility:public"], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", @@ -70,6 +33,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/containers:detection_result", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -86,4 +50,44 @@ cc_library( ], ) +cc_library( + name = "object_detector_graph", + srcs = ["object_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + 
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/calculators/util:detections_deduplicate_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index dd19237ff..2477f8a44 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,8 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +132,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +148,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +165,11 @@ absl::StatusOr> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +190,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result typo. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task. 
struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function>, - const Image&, int64)> + std::function, const Image&, int64)> result_callback = nullptr; }; @@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. // TODO: Describes the output bounding boxes for gpu input // images after enabling the gpu support in MediaPipe Tasks. - absl::StatusOr> Detect( + absl::StatusOr Detect( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. - absl::StatusOr> DetectForVideo( + absl::StatusOr DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options = std::nullopt); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b149cea0f..cb85fc46f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -33,8 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" @@ -52,6 +51,7 @@ namespace vision { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -68,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; using TensorsSource = - mediapipe::tasks::SourceOrNodeOutput>; + mediapipe::api2::builder::Source>; constexpr int kDefaultLocationsIndex = 0; constexpr int kDefaultCategoriesIndex = 1; @@ -532,8 +532,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has 4 outputs. auto& model = *model_resources.GetTfLiteModel(); - if (model.subgraphs()->size() != 1 || - (*model.subgraphs())[0]->outputs()->size() != 4) { + if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected a model with a single subgraph, found %d.", @@ -561,14 +560,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); @@ -583,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { auto post_processing_specs, BuildPostProcessingSpecs(task_options, metadata_extractor)); // Calculators to perform score calibration, if specified in the metadata. - TensorsSource calibrated_tensors = {&inference, kTensorTag}; + TensorsSource calibrated_tensors = + inference.Out(kTensorTag).Cast>(); if (post_processing_specs.score_calibration_options.has_value()) { // Split tensors. auto* split_tensor_vector_node = @@ -622,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { concatenate_tensor_vector_node->In(i); } } - calibrated_tensors = {concatenate_tensor_vector_node, 0}; + calibrated_tensors = + concatenate_tensor_vector_node->Out(0).Cast>(); } // Calculator to convert output tensors to a detection proto vector. // Connects TensorsToDetectionsCalculator's input stream to the output @@ -662,11 +664,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { detection_transformation.Out(kPixelDetectionsTag) >> detection_label_id_to_text.In(""); + // Deduplicate Detections with same bounding box coordinates. + auto& detections_deduplicate = + graph.AddNode("DetectionsDeduplicateCalculator"); + detection_label_id_to_text.Out("") >> detections_deduplicate.In(""); + // Outputs the labeled detections and the processed image as the subgraph // output streams. return {{ /* detections= */ - detection_label_id_to_text[Output>("")], + detections_deduplicate[Output>("")], /* image= */ preprocessing[Output(kImageTag)], }}; } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 1747685dd..798e3f238 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -65,10 +66,14 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; +using ::mediapipe::tasks::components::containers::Detection; +using ::mediapipe::tasks::components::containers::DetectionResult; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; +using DetectionProto = mediapipe::Detection; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kMobileSsdWithMetadata[] = @@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] = // Checks that the two provided `Detection` proto vectors are equal, with a // tolerancy on floating-point scores to account for numerical instabilities. // If the proto definition changes, please also change this function. -void ExpectApproximatelyEqual(const std::vector& actual, - const std::vector& expected) { +void ExpectApproximatelyEqual(const ObjectDetectorResult& actual, + const ObjectDetectorResult& expected) { const float kPrecision = 1e-6; - EXPECT_EQ(actual.size(), expected.size()); - for (int i = 0; i < actual.size(); ++i) { - const Detection& a = actual[i]; - const Detection& b = expected[i]; - EXPECT_THAT(a.location_data().bounding_box(), - EqualsProto(b.location_data().bounding_box())); - EXPECT_EQ(a.label_size(), 1); - EXPECT_EQ(b.label_size(), 1); - EXPECT_EQ(a.label(0), b.label(0)); - EXPECT_EQ(a.score_size(), 1); - EXPECT_EQ(b.score_size(), 1); - EXPECT_NEAR(a.score(0), b.score(0), kPrecision); + EXPECT_EQ(actual.detections.size(), expected.detections.size()); + for (int i = 0; i < actual.detections.size(); ++i) { + const Detection& a = actual.detections[i]; + const Detection& b = expected.detections[i]; + EXPECT_EQ(a.bounding_box, b.bounding_box); + EXPECT_EQ(a.categories.size(), 1); + EXPECT_EQ(b.categories.size(), 1); + EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name); + EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision); } } -std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { - return {ParseTextProtoOrDie(R"pb( +std::vector +GenerateMobileSsdNoImageResizingFullExpectedResults() { + return {ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6328125 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.59765625 location_data { format: BOUNDING_BOX bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "dog" score: 0.48828125 location_data { @@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); 
options->running_mode = running_mode; options->result_callback = - [](absl::StatusOr> detections, - const Image& image, int64 timestamp_ms) {}; + [](absl::StatusOr detections, const Image& image, + int64 timestamp_ms) {}; absl::StatusOr> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.69921875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.64453125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.51171875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.48828125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.69921875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.64453125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.51171875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.48828125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + })pb")})); } TEST_F(ImageModeTest, SucceedsEfficientDetModel) { @@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.7578125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.72265625 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.6289063 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.5859375 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.7578125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.72265625 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.6289063 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } + })pb"), + 
ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.5859375 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } + })pb")})); } TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { @@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, GenerateMobileSsdNoImageResizingFullExpectedResults()); + results, ConvertToDetectionResult( + GenerateMobileSsdNoImageResizingFullExpectedResults())); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6531269142 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, - {full_expected_results[0], full_expected_results[1], - full_expected_results[2]}); + + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1], + full_expected_results[2]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, 
ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.7109375 location_data { format: BOUNDING_BOX bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); @@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->DetectForVideo(image, i)); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } MP_ASSERT_OK(object_detector->Close()); } @@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->running_mode = core::RunningMode::LIVE_STREAM; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); @@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) { auto options = std::make_unique(); options->max_results = 2; options->running_mode = core::RunningMode::LIVE_STREAM; - std::vector> detection_results; + std::vector detection_results; std::vector> image_sizes; std::vector timestamps; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, ×tamps]( - absl::StatusOr> detections, const Image& image, + absl::StatusOr detections, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(detections.status()); detection_results.push_back(std::move(detections).value()); @@ -719,11 +733,12 @@ 
TEST_F(LiveStreamModeTest, Succeeds) { // number of iterations. ASSERT_LE(detection_results.size(), iterations); ASSERT_GT(detection_results.size(), 0); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); for (const auto& detection_result : detection_results) { ExpectApproximatelyEqual( - detection_result, {full_expected_results[0], full_expected_results[1]}); + detection_result, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1]})); } for (const auto& image_size : image_sizes) { EXPECT_EQ(image_size.first, image.width()); diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index cba58ace8..3f6932f8f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index 2ce9e2454..fe4e63824 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -22,13 +22,13 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; -float CalculateArea(const Rect& rect) { +float CalculateArea(const RectF& rect) { return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Rect& a, const Rect& b) { +float CalculateIntersectionArea(const RectF& a, const RectF& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Rect& a, const Rect& b) { +float CalculateIOU(const RectF& a, const RectF& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 73114d2ef..4d1fac62f 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -27,15 +27,15 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. 
-float CalculateIOU(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIOU(const components::containers::RectF& a, + const components::containers::RectF& b); // Calculates area for face bound -float CalculateArea(const components::containers::Rect& rect); +float CalculateArea(const components::containers::RectF& rect); // Calucates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIntersectionArea(const components::containers::RectF& a, + const components::containers::RectF& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ diff --git a/mediapipe/tasks/ios/common/BUILD b/mediapipe/tasks/ios/common/BUILD new file mode 100644 index 000000000..5f13f8d5c --- /dev/null +++ b/mediapipe/tasks/ios/common/BUILD @@ -0,0 +1,25 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommon", + hdrs = [ + "sources/MPPCommon.h", + ], + module_name = "MPPCommon", +) diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h new file mode 100644 index 000000000..d76123fa0 --- /dev/null +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -0,0 +1,110 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum MPPTasksErrorCode + * This enum specifies error codes for errors thrown by iOS MediaPipe Task Library. + */ +typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { + + // Generic error codes. + + /** Indicates the operation was cancelled, typically by the caller. */ + MPPTasksErrorCodeCancelledError = 1, + + /** Indicates an unknown error occurred. */ + MPPTasksErrorCodeUnknownError = 2, + + /** Indicates the caller specified an invalid argument, such as a malformed filename. */ + MPPTasksErrorCodeInvalidArgumentError = 3, + + /** Indicates a deadline expired before the operation could complete. */ + MPPTasksErrorCodeDeadlineExceededError = 4, + + /** Indicates some requested entity (such as a file or directory) was not found. 
*/ + MPPTasksErrorCodeNotFoundError = 5, + + /** + * Indicates that the entity a caller attempted to create (such as a file or directory) is + * already present. + */ + MPPTasksErrorCodeAlreadyExistsError = 6, + + /** Indicates that the caller does not have permission to execute the specified operation. */ + MPPTasksErrorCodePermissionDeniedError = 7, + + /** + * Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire + * file system is out of space. + */ + MPPTasksErrorCodeResourceExhaustedError = 8, + + /** + * Indicates that the operation was rejected because the system is not in a state required for + * the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir" + * operation is applied to a non-directory, etc. + */ + MPPTasksErrorCodeFailedPreconditionError = 9, + + /** + * Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer + * check failure or a failed transaction. + */ + MPPTasksErrorCodeAbortedError = 10, + + /** + * Indicates the operation was attempted past the valid range, such as seeking or reading past an + * end-of-file. + */ + MPPTasksErrorCodeOutOfRangeError = 11, + + /** + * Indicates the operation is not implemented or supported in this service. In this case, the + * operation should not be re-attempted. + */ + MPPTasksErrorCodeUnimplementedError = 12, + + /** + * Indicates an internal error has occurred and some invariants expected by the underlying system + * have not been satisfied. This error code is reserved for serious errors. + */ + MPPTasksErrorCodeInternalError = 13, + + /** + * Indicates the service is currently unavailable and that this is most likely a transient + * condition. + */ + MPPTasksErrorCodeUnavailableError = 14, + + /** Indicates that unrecoverable data loss or corruption has occurred. */ + MPPTasksErrorCodeDataLossError = 15, + + /** + * Indicates that the request does not have valid authentication credentials for the operation. + */ + MPPTasksErrorCodeUnauthenticatedError = 16, + + /** The first error code in MPPTasksErrorCode (for internal use only). */ + MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError, + + /** The last error code in MPPTasksErrorCode (for internal use only). */ + MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError, + +} NS_SWIFT_NAME(TasksErrorCode); + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD new file mode 100644 index 000000000..a29c700da --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
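The MPPCommon.h header added above defines the MPPTasksErrorCode enum. As a hedged illustration (not part of the patch), a client might branch on these codes roughly as follows; the task call that populates the error is hypothetical, and the MPPTasksErrorDomain constant is introduced by MPPCommonUtils.h later in this patch.

// --- illustrative snippet (not part of the patch) ---
NSError *taskError = nil;
// ... some MediaPipe Tasks API call (hypothetical here) populates taskError on failure ...
if (taskError != nil && [taskError.domain isEqualToString:MPPTasksErrorDomain]) {
  switch ((MPPTasksErrorCode)taskError.code) {
    case MPPTasksErrorCodeInvalidArgumentError:
      NSLog(@"Invalid argument: %@", taskError.localizedDescription);
      break;
    case MPPTasksErrorCodeNotFoundError:
      NSLog(@"Missing file or asset: %@", taskError.localizedDescription);
      break;
    default:
      NSLog(@"Task error %ld: %@", (long)taskError.code, taskError.localizedDescription);
      break;
  }
}
// --- end snippet ---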
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommonUtils", + srcs = ["sources/MPPCommonUtils.mm"], + hdrs = ["sources/MPPCommonUtils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/ios/common:MPPCommon", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +objc_library( + name = "NSStringHelpers", + srcs = ["sources/NSString+Helpers.mm"], + hdrs = ["sources/NSString+Helpers.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], +) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h new file mode 100644 index 000000000..69c28b916 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -0,0 +1,80 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/tasks/cc/common.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Error domain of MediaPipe Task related errors. */ +extern NSString *const MPPTasksErrorDomain; + +/** Helper utility for the all tasks which encapsulates common functionality. */ +@interface MPPCommonUtils : NSObject + +/** + * Creates and saves an NSError in the MediPipe task library domain, with the given code and + * description. + * + * @param code Error code. + * @param description Error description. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + */ ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description; + +/** + * Creates and saves an NSError with the given domain, code and description. + * + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + * @param domain Error domain. + * @param code Error code. + * @param description Error description. + */ ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description; + +/** + * Converts an absl::Status to an NSError. + * + * @param status absl::Status. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + * @return YES when there is no error, NO otherwise. + */ ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; + +/** + * Allocates a block of memory with the specified size and returns a pointer to it. If memory + * cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it + * terminates program execution. + * + * @param memSize size of memory to be allocated + * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no + * error will be saved. + * + * @return Pointer to the allocated block of memory on successfull allocation. 
`nil` in case as + * error is encountered because of invalid `memSize`. If failure is due to any other reason, method + * terminates program execution. + */ ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm new file mode 100644 index 000000000..65b551c32 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -0,0 +1,141 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" + +#include + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "mediapipe/tasks/cc/common.h" + +/** Error domain of MediaPipe task library errors. */ +NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; + +namespace { +using absl::StatusCode; +} + +@implementation MPPCommonUtils + ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description { + [MPPCommonUtils createCustomError:error + withDomain:MPPTasksErrorDomain + code:code + description:description]; +} + ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description { + if (error) { + *error = [NSError errorWithDomain:domain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; + } +} + ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error { + if (!memSize) { + [MPPCommonUtils createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:@"memSize cannot be zero."]; + return NULL; + } + + void *allocedMemory = malloc(memSize); + if (!allocedMemory) { + exit(-1); + } + + return allocedMemory; +} + ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError *_Nullable *)error { + if (status.ok()) { + return YES; + } + + // Converts the absl status message to an NSString. + NSString *description = [NSString + stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() + encoding:NSUTF8StringEncoding]; + + MPPTasksErrorCode errorCode = MPPTasksErrorCodeUnknownError; + + // Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits + // absl::StatusCode::kOk. 
+ switch (status.code()) { + case StatusCode::kCancelled: + errorCode = MPPTasksErrorCodeCancelledError; + break; + case StatusCode::kUnknown: + errorCode = MPPTasksErrorCodeUnknownError; + break; + case StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case StatusCode::kDeadlineExceeded: + errorCode = MPPTasksErrorCodeDeadlineExceededError; + break; + case StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeNotFoundError; + break; + case StatusCode::kAlreadyExists: + errorCode = MPPTasksErrorCodeAlreadyExistsError; + break; + case StatusCode::kPermissionDenied: + errorCode = MPPTasksErrorCodePermissionDeniedError; + break; + case StatusCode::kResourceExhausted: + errorCode = MPPTasksErrorCodeResourceExhaustedError; + break; + case StatusCode::kFailedPrecondition: + errorCode = MPPTasksErrorCodeFailedPreconditionError; + break; + case StatusCode::kAborted: + errorCode = MPPTasksErrorCodeAbortedError; + break; + case StatusCode::kOutOfRange: + errorCode = MPPTasksErrorCodeOutOfRangeError; + break; + case StatusCode::kUnimplemented: + errorCode = MPPTasksErrorCodeUnimplementedError; + break; + case StatusCode::kInternal: + errorCode = MPPTasksErrorCodeInternalError; + break; + case StatusCode::kUnavailable: + errorCode = MPPTasksErrorCodeUnavailableError; + break; + case StatusCode::kDataLoss: + errorCode = MPPTasksErrorCodeDataLossError; + break; + case StatusCode::kUnauthenticated: + errorCode = MPPTasksErrorCodeUnauthenticatedError; + break; + default: + break; + } + + [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; + return NO; +} + +@end diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h new file mode 100644 index 000000000..66f9c5ccc --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -0,0 +1,29 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include + +NS_ASSUME_NONNULL_BEGIN + +@interface NSString (Helpers) + +@property(readonly, nonatomic) std::string cppString; + ++ (NSString *)stringWithCppString:(std::string)text; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm new file mode 100644 index 000000000..183ed4365 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -0,0 +1,27 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
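The MPPCommonUtils helpers defined above are meant to bridge absl::Status failures into NSError. A minimal Objective-C++ sketch (not part of the patch) of the intended call pattern, using only the method declared in MPPCommonUtils.h:

// --- illustrative snippet (Objective-C++, not part of the patch) ---
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#include "absl/status/status.h"

static BOOL ReportStatus(NSError **error) {
  // A failing status; checkCppError:toError: maps kInvalidArgument to
  // MPPTasksErrorCodeInvalidArgumentError in the MPPTasksErrorDomain.
  absl::Status status = absl::InvalidArgumentError("model path is empty");
  return [MPPCommonUtils checkCppError:status toError:error];  // returns NO here
}
// --- end snippet ---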
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +@implementation NSString (Helpers) + +- (std::string)cppString { + return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); +} + ++ (NSString *)stringWithCppString:(std::string)text { + return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD new file mode 100644 index 000000000..ee54bb712 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -0,0 +1,43 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategory", + srcs = ["sources/MPPCategory.m"], + hdrs = ["sources/MPPCategory.h"], +) + +objc_library( + name = "MPPClassificationResult", + srcs = ["sources/MPPClassificationResult.m"], + hdrs = ["sources/MPPClassificationResult.h"], + deps = [":MPPCategory"], +) + +objc_library( + name = "MPPEmbedding", + srcs = ["sources/MPPEmbedding.m"], + hdrs = ["sources/MPPEmbedding.h"], +) + +objc_library( + name = "MPPEmbeddingResult", + srcs = ["sources/MPPEmbeddingResult.m"], + hdrs = ["sources/MPPEmbeddingResult.h"], + deps = [":MPPEmbedding"], +) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h new file mode 100644 index 000000000..f360d46da --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -0,0 +1,68 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Category is a util class that contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. Typically it's used as the result of + * classification tasks. 
+ */ +NS_SWIFT_NAME(ResultCategory) +@interface MPPCategory : NSObject + +/** + * The index of the label in the corresponding label file. Set to -1 if the index is + * not set. + */ +@property(nonatomic, readonly) NSInteger index; + +/** Confidence score for this class. */ +@property(nonatomic, readonly) float score; + +/** The label of this category object. */ +@property(nonatomic, readonly, nullable) NSString *categoryName; + +/** + * The display name of the label, which may be translated for different locales. For example, a + * label, "apple", may be translated into Spanish for display purpose, so that the display name is + * "manzana". + */ +@property(nonatomic, readonly, nullable) NSString *displayName; + +/** + * Initializes a new `MPPCategory` with the given index, score, category name and display name. + * + * @param index The index of the label in the corresponding label file. + * @param score The probability score of this label category. + * @param categoryName The label of this category object. + * @param displayName The display name of the label. + * + * @return An instance of `MPPCategory` initialized with the given index, score, category name and + * display name. + */ +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + categoryName:(nullable NSString *)categoryName + displayName:(nullable NSString *)displayName NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m new file mode 100644 index 000000000..824fae65e --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -0,0 +1,33 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +@implementation MPPCategory + +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + categoryName:(nullable NSString *)categoryName + displayName:(nullable NSString *)displayName { + self = [super init]; + if (self) { + _index = index; + _score = score; + _categoryName = categoryName; + _displayName = displayName; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h new file mode 100644 index 000000000..cd464c6a1 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -0,0 +1,116 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Represents the list of classification for a given classifier head. Typically used as a result + * for classification tasks. + */ +NS_SWIFT_NAME(Classifications) +@interface MPPClassifications : NSObject + +/** + * The index of the classifier head these entries refer to. This is useful for multi-head models. + */ +@property(nonatomic, readonly) NSInteger headIndex; + +/** The optional name of the classifier head, which is the corresponding tensor metadata name. */ +@property(nonatomic, readonly, nullable) NSString *headName; + +/** An array of `MPPCategory` objects containing the predicted categories. */ +@property(nonatomic, readonly) NSArray *categories; + +/** + * Initializes a new `MPPClassifications` object with the given head index and array of categories. + * Head name is initialized to `nil`. + * + * @param headIndex The index of the classifier head. + * @param categories An array of `MPPCategory` objects containing the predicted categories. + * + * @return An instance of `MPPClassifications` initialized with the given head index and + * array of categories. + */ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories; + +/** + * Initializes a new `MPPClassifications` with the given head index, head name and array of + * categories. + * + * @param headIndex The index of the classifier head. + * @param headName The name of the classifier head, which is the corresponding tensor metadata + * name. + * @param categories An array of `MPPCategory` objects containing the predicted categories. + * + * @return An object of `MPPClassifications` initialized with the given head index, head name and + * array of categories. + */ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + */ +NS_SWIFT_NAME(ClassificationResult) +@interface MPPClassificationResult : NSObject + +/** + * An Array of `MPPClassifications` objects containing the predicted categories for each head of + * the model. + */ +@property(nonatomic, readonly) NSArray *classifications; + +/** + * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. If it is set to the value -1, it signifies the absence of a timestamp. This is + * only used for classification on time series (e.g. audio classification). In these use cases, the + * amount of data to process might exceed the maximum size that the model can process: to solve + * this, the input data is split into multiple chunks starting at different timestamps. + */ +@property(nonatomic, readonly) NSInteger timestampMs; + +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications and time + * stamp (in milliseconds). 
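Taken together, MPPCategory, MPPClassifications, and MPPClassificationResult are plain value containers. A hedged construction sketch (not part of the patch), using only the initializers declared in these headers; the head name and scores are illustrative values:

// --- illustrative snippet (not part of the patch) ---
MPPCategory *cat = [[MPPCategory alloc] initWithIndex:0
                                                score:0.9f
                                         categoryName:@"cat"
                                          displayName:nil];
MPPClassifications *head =
    [[MPPClassifications alloc] initWithHeadIndex:0
                                         headName:@"probability"
                                       categories:@[ cat ]];
MPPClassificationResult *result =
    [[MPPClassificationResult alloc] initWithClassifications:@[ head ]
                                                  timestampMs:0];
NSLog(@"%lu classifier head(s)", (unsigned long)result.classifications.count);
// --- end snippet ---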
+ * + * @param classifications An Array of `MPPClassifications` objects containing the predicted + * categories for each head of the model. + * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + * + * @return An instance of `MPPClassificationResult` initialized with the given array of + * classifications and timestampMs. + */ +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m new file mode 100644 index 000000000..6d42d22ca --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -0,0 +1,51 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +@implementation MPPClassifications + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories { + self = [super init]; + if (self) { + _headIndex = headIndex; + _headName = headName; + _categories = categories; + } + return self; +} + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; +} + +@end + +@implementation MPPClassificationResult + +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs { + self = [super init]; + if (self) { + _classifications = classifications; + _timestampMs = timestampMs; + } + + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h new file mode 100644 index 000000000..931d4e0b9 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h @@ -0,0 +1,69 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Represents the embedding for a given embedder head. Typically used in embedding tasks. 
+ * + * One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based on + * whether or not the embedder was configured to perform scala quantization. + */ +NS_SWIFT_NAME(Embedding) +@interface MPPEmbedding : NSObject + +/** + * @brief The embedding represented as an `NSArray` of `Float` values. + * Empty if the embedder was configured to perform scalar quantization. + */ +@property(nonatomic, readonly, nullable) NSArray *floatEmbedding; + +/** + * @brief The embedding represented as an `NSArray` of `UInt8` values. + * Empty if the embedder was not configured to perform scalar quantization. + */ +@property(nonatomic, readonly, nullable) NSArray *quantizedEmbedding; + +/** The index of the embedder head these entries refer to. This is useful for multi-head models. */ +@property(nonatomic, readonly) NSInteger headIndex; + +/** The optional name of the embedder head, which is the corresponding tensor metadata name. */ +@property(nonatomic, readonly, nullable) NSString *headName; + +/** + * Initializes a new `MPPEmbedding` with the given float embedding, quantized embedding, head index + * and head name. + * + * @param floatEmbedding The optional Floating-point embedding. + * @param quantizedEmbedding The optional Quantized embedding. + * @param headIndex The index of the embedder head. + * @param headName The optional name of the embedder head. + * + * @return An instance of `MPPEmbedding` initialized with the given float embedding, quantized + * embedding, head index and head name. + */ +- (instancetype)initWithFloatEmbedding:(nullable NSArray *)floatEmbedding + quantizedEmbedding:(nullable NSArray *)quantizedEmbedding + headIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m new file mode 100644 index 000000000..69ca076b2 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.m @@ -0,0 +1,33 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h" + +@implementation MPPEmbedding + +- (instancetype)initWithFloatEmbedding:(nullable NSArray *)floatEmbedding + quantizedEmbedding:(nullable NSArray *)quantizedEmbedding + headIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName { + self = [super init]; + if (self) { + _headIndex = headIndex; + _headName = headName; + _floatEmbedding = floatEmbedding; + _quantizedEmbedding = quantizedEmbedding; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h new file mode 100644 index 000000000..8fd9b9dff --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h @@ -0,0 +1,59 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */ +NS_SWIFT_NAME(EmbeddingResult) +@interface MPPEmbeddingResult : NSObject + +/** + * An Array of `MPPEmbedding` objects containing the embedding results for each head of the model. + */ +@property(nonatomic, readonly) NSArray *embeddings; + +/** + * @brief The optional timestamp (in milliseconds) of the start of the chunk of data corresponding + * to these results. + * This is only used for embedding extraction on time series (e.g. audio embedder). In these use + * cases, the amount of data to process might exceed the maximum size that the model can process. To + * solve this, the input data is split into multiple chunks starting at different timestamps. + */ +@property(nonatomic, readonly) NSInteger timestampMs; + +/** + * Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in + * milliseconds). + * + * @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each + * head of the model. + * @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. Pass `0` if timestamp is absent. + * + * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and + * timestampMs. + */ +- (instancetype)initWithEmbeddings:(NSArray *)embeddings + timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m new file mode 100644 index 000000000..56dd30fdd --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m @@ -0,0 +1,30 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h" + +@implementation MPPEmbeddingResult + +- (instancetype)initWithEmbeddings:(NSArray *)embeddings + timestampMs:(NSInteger)timestampMs { + self = [super init]; + if (self) { + _embeddings = embeddings; + _timestampMs = timestampMs; + } + + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD new file mode 100644 index 000000000..64ca29b88 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -0,0 +1,63 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
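As with the classification containers, MPPEmbedding and MPPEmbeddingResult are simple value types. A hedged construction sketch (not part of the patch); the head name and embedding values are illustrative:

// --- illustrative snippet (not part of the patch) ---
MPPEmbedding *embedding =
    [[MPPEmbedding alloc] initWithFloatEmbedding:@[ @0.1f, @0.2f, @0.3f ]
                              quantizedEmbedding:nil
                                       headIndex:0
                                        headName:@"feature_vector"];
MPPEmbeddingResult *result =
    [[MPPEmbeddingResult alloc] initWithEmbeddings:@[ embedding ]
                                       timestampMs:0];
NSLog(@"dimensions: %lu", (unsigned long)embedding.floatEmbedding.count);
// --- end snippet ---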
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategoryHelpers", + srcs = ["sources/MPPCategory+Helpers.mm"], + hdrs = ["sources/MPPCategory+Helpers.h"], + deps = [ + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPCategory", + ], +) + +objc_library( + name = "MPPClassificationResultHelpers", + srcs = ["sources/MPPClassificationResult+Helpers.mm"], + hdrs = ["sources/MPPClassificationResult+Helpers.h"], + deps = [ + ":MPPCategoryHelpers", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ], +) + +objc_library( + name = "MPPEmbeddingHelpers", + srcs = ["sources/MPPEmbedding+Helpers.mm"], + hdrs = ["sources/MPPEmbedding+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPEmbedding", + ], +) + +objc_library( + name = "MPPEmbeddingResultHelpers", + srcs = ["sources/MPPEmbeddingResult+Helpers.mm"], + hdrs = ["sources/MPPEmbeddingResult+Helpers.h"], + deps = [ + ":MPPEmbeddingHelpers", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/containers:MPPEmbeddingResult", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h new file mode 100644 index 000000000..9a11d1e29 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/classification.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const ::mediapipe::Classification &)classificationProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm new file mode 100644 index 000000000..ff0983139 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -0,0 +1,42 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" + +namespace { +using ClassificationProto = ::mediapipe::Classification; +} + +@implementation MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { + NSString *categoryName; + NSString *displayName; + + if (clasificationProto.has_label()) { + categoryName = [NSString stringWithCppString:clasificationProto.label()]; + } + + if (clasificationProto.has_display_name()) { + displayName = [NSString stringWithCppString:clasificationProto.display_name()]; + } + + return [[MPPCategory alloc] initWithIndex:clasificationProto.index() + score:clasificationProto.score() + categoryName:categoryName + displayName:displayName]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h new file mode 100644 index 000000000..fde436feb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -0,0 +1,35 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto; + +@end + +@interface MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm new file mode 100644 index 000000000..b02b032bb --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -0,0 +1,68 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
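The helper category above converts a mediapipe::Classification proto into the MPPCategory container. A short Objective-C++ sketch (not part of the patch) of that conversion, with illustrative field values:

// --- illustrative snippet (Objective-C++, not part of the patch) ---
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"

static MPPCategory *ExampleCategory() {
  ::mediapipe::Classification proto;
  proto.set_index(3);
  proto.set_score(0.75f);
  proto.set_label("cat");
  return [MPPCategory categoryWithProto:proto];
}
// --- end snippet ---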
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +namespace { +using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications; +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const ClassificationsProto &)classificationsProto { + NSMutableArray *categories = + [NSMutableArray arrayWithCapacity:(NSUInteger)classificationsProto.classification_list() + .classification_size()]; + for (const auto &classification : classificationsProto.classification_list().classification()) { + [categories addObject:[MPPCategory categoryWithProto:classification]]; + } + + NSString *headName; + if (classificationsProto.has_head_name()) { + headName = [NSString stringWithCppString:classificationsProto.head_name()]; + } + + return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index() + headName:headName + categories:categories]; +} + +@end + +@implementation MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + NSMutableArray *classifications = [NSMutableArray + arrayWithCapacity:(NSUInteger)classificationResultProto.classifications_size()]; + for (const auto &classificationsProto : classificationResultProto.classifications()) { + [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]]; + } + + NSInteger timestampMs = 0; + if (classificationResultProto.has_timestamp_ms()) { + timestampMs = (NSInteger)classificationResultProto.timestamp_ms(); + } + + return [[MPPClassificationResult alloc] initWithClassifications:classifications + timestampMs:timestampMs]; + ; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h new file mode 100644 index 000000000..33fb3839d --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h @@ -0,0 +1,27 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
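Similarly, the result-level helper above maps the ClassificationResult proto onto MPPClassificationResult. A hedged Objective-C++ sketch (not part of the patch), with illustrative proto contents:

// --- illustrative snippet (Objective-C++, not part of the patch) ---
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"

static MPPClassificationResult *ExampleClassificationResult() {
  mediapipe::tasks::components::containers::proto::ClassificationResult proto;
  auto *classifications = proto.add_classifications();
  classifications->set_head_index(0);
  classifications->mutable_classification_list()->add_classification()->set_label("cat");
  return [MPPClassificationResult classificationResultWithProto:proto];
}
// --- end snippet ---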
+ +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPEmbedding (Helpers) + ++ (MPPEmbedding *)embeddingWithProto: + (const ::mediapipe::tasks::components::containers::proto::Embedding &)embeddingProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm new file mode 100644 index 000000000..76bb75032 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.mm @@ -0,0 +1,60 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h" + +#include + +namespace { +using EmbeddingProto = ::mediapipe::tasks::components::containers::proto::Embedding; +} + +@implementation MPPEmbedding (Helpers) + ++ (MPPEmbedding *)embeddingWithProto:(const EmbeddingProto &)embeddingProto { + + NSMutableArray *floatEmbedding; + NSMutableArray *quantizedEmbedding; + NSString *headName; + + if (embeddingProto.has_float_embedding()) { + floatEmbedding = + [NSMutableArray arrayWithCapacity:embeddingProto.float_embedding().values_size()]; + + for (const auto value : embeddingProto.float_embedding().values()) { + [floatEmbedding addObject:[NSNumber numberWithFloat:value]]; + } + } + + if (embeddingProto.has_quantized_embedding()) { + const std::string &cppQuantizedEmbedding = embeddingProto.quantized_embedding().values(); + quantizedEmbedding = [NSMutableArray arrayWithCapacity:cppQuantizedEmbedding.length()]; + + for (char ch : cppQuantizedEmbedding) { + [quantizedEmbedding addObject:[NSNumber numberWithChar:ch]]; + } + } + + if (embeddingProto.has_head_name()) { + headName = [NSString stringWithCppString:embeddingProto.head_name()]; + } + + return [[MPPEmbedding alloc] initWithFloatEmbedding:floatEmbedding + quantizedEmbedding:quantizedEmbedding + headIndex:embeddingProto.head_index() + headName:headName]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h new file mode 100644 index 000000000..cc53c3e25 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
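A brief sketch (not part of this diff) of the embedding conversion implemented above, assuming MPPEmbedding exposes properties that mirror its initializer arguments; values are placeholders:

```objectivec
// Hypothetical usage: a float Embedding proto converts into MPPEmbedding,
// whose float values are boxed into NSNumber objects.
mediapipe::tasks::components::containers::proto::Embedding embeddingProto;
embeddingProto.set_head_index(0);
embeddingProto.mutable_float_embedding()->add_values(0.1f);
embeddingProto.mutable_float_embedding()->add_values(0.2f);

MPPEmbedding *embedding = [MPPEmbedding embeddingWithProto:embeddingProto];
// Expected: embedding.floatEmbedding == @[ @0.1f, @0.2f ], headIndex == 0,
// and quantizedEmbedding == nil because no quantized values were set.
```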
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPEmbeddingResult (Helpers) + ++ (MPPEmbeddingResult *)embeddingResultWithProto: + (const ::mediapipe::tasks::components::containers::proto::EmbeddingResult &) + embeddingResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm new file mode 100644 index 000000000..f9863e9ca --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm @@ -0,0 +1,42 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbedding+Helpers.h" + +namespace { +using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +} + +@implementation MPPEmbeddingResult (Helpers) + ++ (MPPEmbeddingResult *)embeddingResultWithProto: + (const EmbeddingResultProto &)embeddingResultProto { + NSMutableArray *embeddings = + [NSMutableArray arrayWithCapacity:(NSUInteger)embeddingResultProto.embeddings_size()]; + for (const auto &embeddingProto : embeddingResultProto.embeddings()) { + [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; + } + + NSInteger timestampMs = 0; + if (embeddingResultProto.has_timestamp_ms()) { + timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); + } + + return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs]; +} + +@end diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD new file mode 100644 index 000000000..f9489acfe --- /dev/null +++ b/mediapipe/tasks/ios/core/BUILD @@ -0,0 +1,92 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
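For orientation, a sketch (not part of this diff) of the embedding-result conversion above, showing the default timestamp behavior; property names on MPPEmbeddingResult are assumed to mirror its initializer:

```objectivec
// Hypothetical usage: an EmbeddingResult proto without a timestamp converts
// to an MPPEmbeddingResult whose timestampMs defaults to 0.
mediapipe::tasks::components::containers::proto::EmbeddingResult embeddingResultProto;
embeddingResultProto.add_embeddings()->set_head_index(0);

MPPEmbeddingResult *embeddingResult =
    [MPPEmbeddingResult embeddingResultWithProto:embeddingResultProto];
// Expected: embeddingResult.embeddings.count == 1, embeddingResult.timestampMs == 0.
```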
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseOptions", + srcs = ["sources/MPPBaseOptions.m"], + hdrs = ["sources/MPPBaseOptions.h"], +) + +objc_library( + name = "MPPTaskOptions", + srcs = ["sources/MPPTaskOptions.m"], + hdrs = ["sources/MPPTaskOptions.h"], + deps = [":MPPBaseOptions"], +) + +objc_library( + name = "MPPTaskResult", + srcs = ["sources/MPPTaskResult.m"], + hdrs = ["sources/MPPTaskResult.h"], +) + +objc_library( + name = "MPPTaskOptionsProtocol", + hdrs = ["sources/MPPTaskOptionsProtocol.h"], + deps = ["//mediapipe/framework:calculator_options_cc_proto"], +) + +objc_library( + name = "MPPTaskInfo", + srcs = ["sources/MPPTaskInfo.mm"], + hdrs = ["sources/MPPTaskInfo.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + ":MPPTaskOptions", + ":MPPTaskOptionsProtocol", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPTextPacketCreator", + srcs = ["sources/MPPTextPacketCreator.mm"], + hdrs = ["sources/MPPTextPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPTaskRunner", + srcs = ["sources/MPPTaskRunner.mm"], + hdrs = ["sources/MPPTaskRunner.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h new file mode 100644 index 000000000..088d2d5da --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -0,0 +1,48 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks delegate. + */ +typedef NS_ENUM(NSUInteger, MPPDelegate) { + /** CPU. */ + MPPDelegateCPU, + + /** GPU. */ + MPPDelegateGPU +} NS_SWIFT_NAME(Delegate); + +/** + * Holds the base options that is used for creation of any type of task. 
It has fields with + * important information acceleration configuration, TFLite model source etc. + */ +NS_SWIFT_NAME(BaseOptions) +@interface MPPBaseOptions : NSObject + +/** The path to the model asset to open and mmap in memory. */ +@property(nonatomic, copy) NSString *modelAssetPath; + +/** + * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default + * delegate CPU is used. + */ +@property(nonatomic) MPPDelegate delegate; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m new file mode 100644 index 000000000..eaf2aa895 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -0,0 +1,36 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPBaseOptions + +- (instancetype)init { + self = [super init]; + if (self) { + self.modelAssetPath = [[NSString alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; + + baseOptions.modelAssetPath = self.modelAssetPath; + baseOptions.delegate = self.delegate; + + return baseOptions; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h new file mode 100644 index 000000000..b94e704d1 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -0,0 +1,70 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/framework/calculator.pb.h" + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds all needed informaton to initialize a MediaPipe Task. + */ +@interface MPPTaskInfo : NSObject + +@property(nonatomic, copy, nonnull) NSString *taskGraphName; + +/** + * A task-specific options that is derived from MPPTaskOptions and confirms to + * MPPTaskOptionsProtocol. + */ +@property(nonatomic, copy) id taskOptions; + +/** + * List of task graph input stream info strings in the form TAG:name. + */ +@property(nonatomic, copy) NSArray *inputStreams; + +/** + * List of task graph output stream info in the form TAG:name. + */ +@property(nonatomic, copy) NSArray *outputStreams; + +/** + * If the task requires a flow limiter. 
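A minimal sketch (not part of this diff) of configuring the base options defined above; the model path is a placeholder:

```objectivec
// Hypothetical usage: set the model asset path and keep the default CPU
// delegate (the GPU delegate is not implemented yet per this change).
MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
baseOptions.modelAssetPath = @"/path/to/model.tflite";
baseOptions.delegate = MPPDelegateCPU;
```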
+ */ +@property(nonatomic) BOOL enableFlowLimiting; + ++ (instancetype)new NS_UNAVAILABLE; + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray *)inputStreams + outputStreams:(NSArray *)outputStreams + taskOptions:(id)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. + */ +- (::mediapipe::CalculatorGraphConfig)generateGraphConfig; + +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm new file mode 100644 index 000000000..80ff594a2 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -0,0 +1,136 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_options.pb.h" + +namespace { +using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; +using Node = ::mediapipe::CalculatorGraphConfig::Node; +using ::mediapipe::FlowLimiterCalculatorOptions; +using ::mediapipe::InputStreamInfo; +} // namespace + +@implementation MPPTaskInfo + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray *)inputStreams + outputStreams:(NSArray *)outputStreams + taskOptions:(id)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error { + if (!taskGraphName || !inputStreams.count || !outputStreams.count) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Task graph's name, input streams, and output streams should be non-empty."]; + } + + self = [super init]; + + if (self) { + _taskGraphName = taskGraphName; + _inputStreams = inputStreams; + _outputStreams = outputStreams; + _taskOptions = taskOptions; + _enableFlowLimiting = enableFlowLimiting; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] init]; + + taskInfo.taskGraphName = self.taskGraphName; + taskInfo.inputStreams = self.inputStreams; + taskInfo.outputStreams = self.outputStreams; + taskInfo.taskOptions = self.taskOptions; + taskInfo.enableFlowLimiting = self.enableFlowLimiting; + + return taskInfo; +} + +- (CalculatorGraphConfig)generateGraphConfig { + CalculatorGraphConfig graphConfig; + + Node *taskSubgraphNode = graphConfig.add_node(); + taskSubgraphNode->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:taskSubgraphNode->mutable_options()]; + + for (NSString 
*outputStream in self.outputStreams) { + auto cppOutputStream = std::string(outputStream.cppString); + taskSubgraphNode->add_output_stream(cppOutputStream); + graphConfig.add_output_stream(cppOutputStream); + } + + if (!self.enableFlowLimiting) { + for (NSString *inputStream in self.inputStreams) { + auto cppInputStream = inputStream.cppString; + taskSubgraphNode->add_input_stream(cppInputStream); + graphConfig.add_input_stream(cppInputStream); + } + return graphConfig; + } + + Node *flowLimitCalculatorNode = graphConfig.add_node(); + + flowLimitCalculatorNode->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo *inputStreamInfo = flowLimitCalculatorNode->add_input_stream_info(); + inputStreamInfo->set_tag_index("FINISHED"); + inputStreamInfo->set_back_edge(true); + + FlowLimiterCalculatorOptions *flowLimitCalculatorOptions = + flowLimitCalculatorNode->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flowLimitCalculatorOptions->set_max_in_flight(1); + flowLimitCalculatorOptions->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graphConfig.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flowLimitCalculatorNode->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + taskSubgraphNode->add_input_stream(taskInputStream.cppString); + + NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flowLimitCalculatorNode->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finishedOutputStream = "FINISHED:" + firstOutputStream.cppString; + flowLimitCalculatorNode->add_input_stream(finishedOutputStream); + + return graphConfig; +} + ++ (NSString *)stripTagIndex:(NSString *)tagIndexName { + return [tagIndexName componentsSeparatedByString:@":"][1]; +} + ++ (NSString *)addStreamNamePrefix:(NSString *)tagIndexName { + NSArray *splits = [tagIndexName componentsSeparatedByString:@":"]; + return [NSString stringWithFormat:@"%@:throttled_%@", splits[0], splits[1]]; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h new file mode 100644 index 000000000..e10678348 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -0,0 +1,34 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskOptions) + +@interface MPPTaskOptions : NSObject +/** + * Base options for configuring the MediaPipe task. 
+ */ +@property(nonatomic, copy) MPPBaseOptions *baseOptions; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m new file mode 100644 index 000000000..fe74517c3 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -0,0 +1,36 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPTaskOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _baseOptions = [[MPPBaseOptions alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTaskOptions *taskOptions = [[MPPTaskOptions alloc] init]; + + taskOptions.baseOptions = self.baseOptions; + + return taskOptions; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h new file mode 100644 index 000000000..c03165c1d --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -0,0 +1,33 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/framework/calculator_options.pb.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Any MediaPipe task options should confirm to this protocol. + */ +@protocol MPPTaskOptionsProtocol + +/** + * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto. + */ +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h new file mode 100644 index 000000000..4ee7b2fc6 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -0,0 +1,37 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
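To illustrate how MPPTaskInfo and MPPTaskOptionsProtocol fit together, a hedged sketch (not part of this diff): the graph name, stream names, and `someTaskOptions` (any MPPTaskOptions subclass conforming to MPPTaskOptionsProtocol) are placeholders:

```objectivec
// Hypothetical usage: build a task graph config with flow limiting enabled.
NSError *error = nil;
MPPTaskInfo *taskInfo =
    [[MPPTaskInfo alloc] initWithTaskGraphName:@"mediapipe.tasks.SomeTaskGraph"
                                  inputStreams:@[ @"TEXT:text_in" ]
                                 outputStreams:@[ @"CLASSIFICATIONS:classifications_out" ]
                                   taskOptions:someTaskOptions
                            enableFlowLimiting:YES
                                         error:&error];

// With flow limiting enabled, generateGraphConfig adds a FlowLimiterCalculator
// node whose throttled streams feed the task subgraph's inputs.
mediapipe::CalculatorGraphConfig graphConfig = [taskInfo generateGraphConfig];
```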
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks result base class. Any MediaPipe task result class should extend + * this class. + */ +NS_SWIFT_NAME(TaskResult) + +@interface MPPTaskResult : NSObject +/** + * Timestamp that is associated with the task result object. + */ +@property(nonatomic, assign, readonly) NSInteger timestampMs; + +- (instancetype)init NS_UNAVAILABLE; + +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m new file mode 100644 index 000000000..6c08014ff --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -0,0 +1,31 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +@implementation MPPTaskResult + +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { + self = [super init]; + if (self) { + _timestampMs = timestampMs; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h new file mode 100644 index 000000000..704fc453f --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -0,0 +1,90 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any MediaPipe task. + * + * An instance of the newly created C++ task runner will be stored until this class is destroyed. + * When methods are called for processing (performing inference), closing etc., on this class, + * internally the appropriate methods will be called on the C++ task runner instance to execute the + * appropriate actions. For each type of task, a subclass of this class must be defined to add any + * additional functionality. For eg:, vision tasks must create an `MPPVisionTaskRunner` and provide + * additional functionality. 
An instance of `MPPVisionTaskRunner` can in turn be used by the each + * vision task for creation and execution of the task. Please see the documentation for the C++ Task + * Runner for more details on how the taks runner operates. + */ +@interface MPPTaskRunner : NSObject + +/** + * Initializes a new `MPPTaskRunner` with the MediaPipe calculator configuration proto and an + * optional C++ packets callback. + * + * You can pass `nullptr` for `packetsCallback` in case the mode of operation requested by the user + * is synchronous. + * + * If the task is operating in asynchronous mode, any iOS MediaPipe task that uses the + * `MPPTaskRunner` must define a C++ callback function to obtain the results of inference + * asynchronously and deliver the results to the user. To accomplish this, the callback function + * should in turn invoke the block provided by the user in the task options supplied to create the + * task. Please see the documentation of the C++ Task Runner for more information on the synchronous + * and asynchronous modes of operation. + * + * @param graphConfig A mediapipe task graph config proto. + * @param packetsCallback An optional C++ callback function that takes a list of output packets as + * the input argument. If provided, the callback must in turn call the block provided by the user in + * the appropriate task options. + * + * @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional + * packetsCallback. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + packetsCallback: + (mediapipe::tasks::core::PacketsCallback)packetsCallback + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * A synchronous method for processing batch data or offline streaming data. This method is designed + * for processing either batch data such as unrelated images and texts or offline streaming data + * such as the decoded frames from a video file or audio file. The call blocks the current + * thread until a failure status or a successful result is returned. If the input packets have no + * timestamp, an internal timestamp will be assigend per invocation. Otherwise, when the timestamp + * is set in the input packets, the caller must ensure that the input packet timestamps are greater + * than the timestamps of the previous invocation. This method is thread-unsafe and it is the + * caller's responsibility to synchronize access to this method across multiple threads and to + * ensure that the input packet timestamps are in order. + */ +- (absl::StatusOr)process: + (const mediapipe::tasks::core::PacketMap &)packetMap; + +/** + * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the + * runner are illegal and will receive errors. + */ +- (absl::Status)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm new file mode 100644 index 000000000..eb777679a --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -0,0 +1,61 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; +using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; +using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; +} // namespace + +@interface MPPTaskRunner () { + // Cpp Task Runner + std::unique_ptr _cppTaskRunner; +} +@end + +@implementation MPPTaskRunner + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + packetsCallback:(PacketsCallback)packetsCallback + error:(NSError **)error { + self = [super init]; + if (self) { + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig), + absl::make_unique(), + std::move(packetsCallback)); + + if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { + return nil; + } + _cppTaskRunner = std::move(taskRunnerResult.value()); + } + return self; +} + +- (absl::StatusOr)process:(const PacketMap &)packetMap { + return _cppTaskRunner->Process(packetMap); +} + +- (absl::Status)close { + return _cppTaskRunner->Close(); +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h new file mode 100644 index 000000000..03f946dd0 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h @@ -0,0 +1,26 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "mediapipe/framework/packet.h" + +/* This class is an Objective-C wrapper around a MediaPipe graph object, and + * helps interface it with iOS technologies such as AVFoundation. + */ +@interface MPPTextPacketCreator : NSObject + ++ (mediapipe::Packet)createWithText:(NSString *)text; + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm new file mode 100644 index 000000000..ca86e7a0b --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
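A sketch (not part of this diff) of driving the task runner above in synchronous mode; `graphConfig` stands for a config such as one produced by -[MPPTaskInfo generateGraphConfig], and the stream name is a placeholder:

```objectivec
// Hypothetical usage: create a runner without a packets callback (synchronous
// mode), feed it a text packet, and read the outputs.
NSError *error = nil;
MPPTaskRunner *taskRunner =
    [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:graphConfig
                                         packetsCallback:nullptr
                                                   error:&error];

mediapipe::tasks::core::PacketMap inputPacketMap;
inputPacketMap["text_in"] = [MPPTextPacketCreator createWithText:@"some input text"];

// process: blocks until the graph returns a result or a failure status.
auto statusOrOutputPacketMap = [taskRunner process:inputPacketMap];

// After close, sending further input to the runner is an error.
absl::Status closeStatus = [taskRunner close];
```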
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +namespace { +using ::mediapipe::MakePacket; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPTextPacketCreator + ++ (Packet)createWithText:(NSString *)text { + return MakePacket(text.cppString); +} + +@end diff --git a/mediapipe/tasks/ios/core/utils/BUILD b/mediapipe/tasks/ios/core/utils/BUILD new file mode 100644 index 000000000..6577b03b2 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/BUILD @@ -0,0 +1,29 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseOptionsHelpers", + srcs = ["sources/MPPBaseOptions+Helpers.mm"], + hdrs = ["sources/MPPBaseOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/ios/core:MPPBaseOptions", + ], +) diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h new file mode 100644 index 000000000..d52df2ae4 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h @@ -0,0 +1,26 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPBaseOptions (Helpers) + +- (void)copyToProto:(mediapipe::tasks::core::proto::BaseOptions *)baseOptionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm new file mode 100644 index 000000000..42cafe610 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -0,0 +1,47 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +namespace { +using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; +} + +@implementation MPPBaseOptions (Helpers) + +- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { + baseOptionsProto->Clear(); + + if (self.modelAssetPath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); + } + + switch (self.delegate) { + case MPPDelegateCPU: { + baseOptionsProto->mutable_acceleration()->mutable_tflite(); + break; + } + case MPPDelegateGPU: { + // TODO: Provide an implementation for GPU Delegate. + [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."]; + break; + } + default: + break; + } +} + +@end diff --git a/mediapipe/tasks/ios/ios.bzl b/mediapipe/tasks/ios/ios.bzl new file mode 100644 index 000000000..8fe2a24a1 --- /dev/null +++ b/mediapipe/tasks/ios/ios.bzl @@ -0,0 +1,3 @@ +"""MediaPipe Task Library Helper Rules for iOS""" + +MPP_TASK_MINIMUM_OS_VERSION = "11.0" diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD new file mode 100644 index 000000000..3b533646e --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -0,0 +1,81 @@ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library", +) +load( + "//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner", +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. 
+TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextClassifierObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextClassifierTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], +) + +ios_unit_test( + name = "MPPTextClassifierObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierObjcTestLibrary", + ], +) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m new file mode 100644 index 000000000..a80fb8824 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -0,0 +1,275 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
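The tests that follow exercise MPPTextClassifier end to end. For orientation, a minimal sketch of the calls they rely on (the model path is a placeholder):

```objectivec
// Minimal classification call, mirroring the API used throughout these tests.
NSError *error = nil;
MPPTextClassifier *textClassifier =
    [[MPPTextClassifier alloc] initWithModelPath:@"/path/to/bert_text_classifier.tflite"
                                           error:&error];

MPPTextClassifierResult *result =
    [textClassifier classifyText:@"it's a charming and often affecting journey" error:&error];

// The categories for the single classification head, as asserted below.
NSArray<MPPCategory *> *categories =
    result.classificationResult.classifications[0].categories;
```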
+ +#import + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kRegexTextClassifierModelName = + @"test_model_text_classifier_with_regex_tokenizer"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertEqualCategoryArrays(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index, @"index i = %d", i); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-3, \ + @"index i = %d", i); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName, \ + @"index i = %d", i); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName, \ + @"index i = %d", i); \ + } + +#define AssertTextClassifierResultHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + ++ (NSArray *)expectedBertResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil], + [[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForEdgeCaseTests { + return @[ [[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil] ]; +} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextClassifierOptions 
*)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions + failsWithExpectedError:(NSError *)expectedError { + NSError *error = nil; + MPPTextClassifier *textClassifier = + [[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error]; + XCTAssertNil(textClassifier); + AssertEqualErrors(error, expectedError); +} + +- (void)assertResultsOfClassifyText:(NSString *)text + usingTextClassifier:(MPPTextClassifier *)textClassifier + equalsCategories:(NSArray *)expectedCategories { + MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil]; + AssertTextClassifierResultHasOneHead(negativeResult); + AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories, + expectedCategories); +} + +- (void)testCreateTextClassifierFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textClassifier); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testCreateTextClassifierFailsWithBothAllowlistAndDenylist { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"positive" ]; + options.categoryDenylist = @[ @"negative" ]; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: `category_allowlist` and " + @"`category_denylist` are mutually exclusive options." + }]]; +} + +- (void)testCreateTextClassifierFailsWithInvalidMaxResults { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 0; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: Invalid `max_results` option: " + @"value must be != 0." 
+ }]]; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kBertTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForNegativeText]]; + + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithRegexSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kRegexTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForNegativeText]]; + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 1; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryAllowlistSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"negative" ]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options + error:&error]; + XCTAssertNotNil(textClassifier); + XCTAssertNil(error); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryDenylistSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryDenylist = @[ @"positive" ]; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.scoreThreshold = 0.5f; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift new file mode 100644 index 000000000..a103283eb --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -0,0 +1,264 @@ +// Copyright 2023 The 
MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import MPPCommon +import XCTest + +@testable import MPPTextClassifier + +class TextClassifierTests: XCTestCase { + + static let bundle = Bundle(for: TextClassifierTests.self) + + static let bertModelPath = bundle.path( + forResource: "bert_text_classifier", + ofType: "tflite") + + static let positiveText = "it's a charming and often affecting journey" + + static let negativeText = "unflinchingly bleak and desperate" + + static let bertNegativeTextResults = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ResultCategory( + index: 1, + score: 0.043812, + categoryName: "positive", + displayName: nil), + ] + + static let bertNegativeTextResultsForEdgeTestCases = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil) + ] + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription: String + ) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertCategoriesAreEqual( + category: ResultCategory, + expectedCategory: ResultCategory, + indexInCategoryList: Int + ) { + XCTAssertEqual( + category.index, + expectedCategory.index, + String( + format: """ + category[%d].index and expectedCategory[%d].index are not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.score, + expectedCategory.score, + accuracy: 1e-3, + String( + format: """ + category[%d].score and expectedCategory[%d].score are not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.categoryName, + expectedCategory.categoryName, + String( + format: """ + category[%d].categoryName and expectedCategory[%d].categoryName are \ + not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.displayName, + expectedCategory.displayName, + String( + format: """ + category[%d].displayName and expectedCategory[%d].displayName are \ + not equal. + """, indexInCategoryList)) + } + + func assertEqualCategoryArrays( + categoryArray: [ResultCategory], + expectedCategoryArray: [ResultCategory] + ) { + XCTAssertEqual( + categoryArray.count, + expectedCategoryArray.count) + + for (index, (category, expectedCategory)) in zip(categoryArray, expectedCategoryArray) + .enumerated() + { + assertCategoriesAreEqual( + category: category, + expectedCategory: expectedCategory, + indexInCategoryList: index) + } + } + + func assertTextClassifierResultHasOneHead( + _ textClassifierResult: TextClassifierResult + ) { + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1) + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0) + } + + func textClassifierOptionsWithModelPath( + _ modelPath: String? 
+ ) throws -> TextClassifierOptions { + let modelPath = try XCTUnwrap(modelPath) + + let textClassifierOptions = TextClassifierOptions() + textClassifierOptions.baseOptions.modelAssetPath = modelPath + + return textClassifierOptions + } + + func assertCreateTextClassifierThrowsError( + textClassifierOptions: TextClassifierOptions, + expectedErrorDescription: String + ) { + do { + let textClassifier = try TextClassifier(options: textClassifierOptions) + XCTAssertNil(textClassifier) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: expectedErrorDescription) + } + } + + func assertResultsForClassify( + text: String, + using textClassifier: TextClassifier, + equals expectedCategories: [ResultCategory] + ) throws { + let textClassifierResult = + try XCTUnwrap( + textClassifier.classify(text: text)) + assertTextClassifierResultHasOneHead(textClassifierResult) + assertEqualCategoryArrays( + categoryArray: + textClassifierResult.classificationResult.classifications[0].categories, + expectedCategoryArray: expectedCategories) + } + + func testCreateTextClassifierWithInvalidMaxResultsFails() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.maxResults = 0 + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0. + """) + } + + func testCreateTextClassifierWithCategoryAllowlistAndDenylistFails() throws { + + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.categoryAllowlist = ["positive"] + textClassifierOptions.categoryDenylist = ["positive"] + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \ + mutually exclusive options. 
+ """) + } + + func testClassifyWithBertSucceeds() throws { + + let modelPath = try XCTUnwrap(TextClassifierTests.bertModelPath) + let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath)) + + try assertResultsForClassify( + text: TextClassifierTests.negativeText, + using: textClassifier, + equals: TextClassifierTests.bertNegativeTextResults) + } + + func testClassifyWithMaxResultsSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.maxResults = 1 + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.negativeText, + using: textClassifier, + equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryAllowlistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.categoryAllowlist = ["negative"] + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.negativeText, + using: textClassifier, + equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryDenylistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.categoryDenylist = ["positive"] + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.negativeText, + using: textClassifier, + equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithScoreThresholdSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + textClassifierOptions.scoreThreshold = 0.5 + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.negativeText, + using: textClassifier, + equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + } + +} diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD new file mode 100644 index 000000000..e07d92979 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -0,0 +1,28 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextTaskRunner", + srcs = ["sources/MPPTextTaskRunner.mm"], + hdrs = ["sources/MPPTextTaskRunner.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = ["//mediapipe/tasks/ios/core:MPPTaskRunner"], +) diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h new file mode 100644 index 000000000..ba83a1141 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h @@ -0,0 +1,43 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any MediaPipe text task. + */ +@interface MPPTextTaskRunner : MPPTaskRunner + +/** + * Initializes a new `MPPTextTaskRunner` with the MediaPipe calculator config proto. + * + * @param graphConfig A MediaPipe calculator config proto. + * + * @return An instance of `MPPTextTaskRunner` initialized to the given MediaPipe calculator config + * proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm new file mode 100644 index 000000000..b539eb133 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm @@ -0,0 +1,38 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +} // namespace + +@implementation MPPTextTaskRunner + +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + packetsCallback: + (mediapipe::tasks::core::PacketsCallback)packetsCallback + error:(NSError **)error { + self = [super initWithCalculatorGraphConfig:graphConfig + packetsCallback:packetsCallback + error:error]; + return self; +} + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + return [self initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error]; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD new file mode 100644 index 000000000..855a86496 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -0,0 +1,61 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptions", + srcs = ["sources/MPPTextClassifierOptions.m"], + hdrs = ["sources/MPPTextClassifierOptions.h"], + deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"], +) + +objc_library( + name = "MPPTextClassifierResult", + srcs = ["sources/MPPTextClassifierResult.m"], + hdrs = ["sources/MPPTextClassifierResult.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + "//mediapipe/tasks/ios/core:MPPTaskResult", + ], +) + +objc_library( + name = "MPPTextClassifier", + srcs = ["sources/MPPTextClassifier.mm"], + hdrs = ["sources/MPPTextClassifier.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPTextClassifier", + deps = [ + ":MPPTextClassifierOptions", + ":MPPTextClassifierResult", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h new file mode 100644 index 000000000..33d3c8970 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -0,0 +1,102 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs classification on text. + * + * This API expects a TFLite model with (optional) [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory + * (described below) input tensors, output tensor, and the optional (but recommended) label + * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. + * + * Metadata is required for models with int32 input tensors because it contains the input + * process unit for the model's Tokenizer. No metadata is required for models with string + * input tensors. + * + * Input tensors + * - Three input tensors `kTfLiteInt32` of shape `[batch_size xbert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires + * a Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size xmax_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. + * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor (`kTfLiteFloat32/kBool`) with: + * - `N` classes and shape `[1 x N]` + * - optional (but recommended) label map(s) as AssociatedFiles with type TENSOR_AXIS_LABELS, + * containing one label per line. The first such AssociatedFile (if any) is used to fill the + * `categoryName` field of the results. The `displayName` field is filled from the + * AssociatedFile (if any) whose locale matches the `displayNamesLocale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If none of + * these are available, only the `index` field of the results will be filled. + */ +NS_SWIFT_NAME(TextClassifier) +@interface MPPTextClassifier : NSObject + +/** + * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `MPPTextClassifierOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an + * error in initializing the text classifier. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. 
+ * + * @param options The options of type `MPPTextClassifierOptions` to use for configuring the + * `MPPTextClassifier`. + * @param error An optional error parameter populated when there is an error in initializing the + * text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an + * error in initializing the text classifier. + */ +- (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs classification on the input text. + * + * @param text The `NSString` on which classification is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * classification on the input text. + * + * @return A `MPPTextClassifierResult` object that contains a list of text classifications. + */ +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(classify(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm new file mode 100644 index 000000000..52e4d92ac --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -0,0 +1,97 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
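
The header above also documents the error contract: the initializers and `classifyText:error:` return `nil` and populate the `NSError **` out-parameter on failure. From Swift that surfaces as a throwing API; the following hedged sketch (the model path is a caller-supplied placeholder) shows one way to handle it.

import MPPTextClassifier

// Sketch of the documented error contract as seen from Swift: the nullable
// return plus NSError ** pattern above becomes a throwing call.
func classifyOrReport(text: String, modelPath: String) {
  do {
    let classifier = try TextClassifier(modelPath: modelPath)
    let result = try classifier.classify(text: text)
    print("returned \(result.classificationResult.classifications.count) classification head(s)")
  } catch {
    // Initialization or classification failures documented above land here.
    print("Text classification failed: \(error.localizedDescription)")
  }
}
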
+ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +@interface MPPTextClassifier () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextClassifier + +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + + if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + return nil; + } + + return [MPPTextClassifierResult + textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() + [kClassificationsStreamName.cppString]]; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h new file mode 100644 index 000000000..55ab020f7 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -0,0 +1,61 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options for setting up a `MPPTextClassifier`. + */ +NS_SWIFT_NAME(TextClassifierOptions) +@interface MPPTextClassifierOptions : MPPTaskOptions + +/** + * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults + * to English. + */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** + * The maximum number of top-scored classification results to return. If < 0, all available results + * will be returned. If 0, an invalid argument error is returned. + */ +@property(nonatomic) NSInteger maxResults; + +/** + * Score threshold to override the one provided in the model metadata (if any). Results below this + * value are rejected. + */ +@property(nonatomic) float scoreThreshold; + +/** + * The allowlist of category names. If non-empty, detection results whose category name is not in + * this set will be filtered out. Duplicate or unknown category names are ignored. Mutually + * exclusive with categoryDenylist. + */ +@property(nonatomic, copy) NSArray *categoryAllowlist; + +/** + * The denylist of category names. If non-empty, detection results whose category name is in this + * set will be filtered out. Duplicate or unknown category names are ignored. Mutually exclusive + * with categoryAllowlist. + */ +@property(nonatomic, copy) NSArray *categoryDenylist; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m new file mode 100644 index 000000000..2d5c17cda --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -0,0 +1,40 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
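
To make the option semantics documented in MPPTextClassifierOptions.h above concrete, a short illustrative Swift configuration follows; the values are arbitrary and the comments only restate the documented behavior.

import MPPTextClassifier

// Illustrative option values; comments restate the documented semantics.
let options = TextClassifierOptions()
options.displayNamesLocale = "en"                     // locale for display names from model metadata
options.maxResults = -1                               // < 0 returns all results; 0 fails at creation time
options.scoreThreshold = 0.25                         // categories scoring below this are rejected
options.categoryAllowlist = ["positive", "negative"]  // mutually exclusive with categoryDenylist
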
+ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +@implementation MPPTextClassifierOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _maxResults = -1; + _scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTextClassifierOptions *textClassifierOptions = [super copyWithZone:zone]; + + textClassifierOptions.scoreThreshold = self.scoreThreshold; + textClassifierOptions.maxResults = self.maxResults; + textClassifierOptions.categoryDenylist = self.categoryDenylist; + textClassifierOptions.categoryAllowlist = self.categoryAllowlist; + textClassifierOptions.displayNamesLocale = self.displayNamesLocale; + + return textClassifierOptions; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h new file mode 100644 index 000000000..6744a8e16 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the classification results generated by `MPPTextClassifier`. **/ +NS_SWIFT_NAME(TextClassifierResult) +@interface MPPTextClassifierResult : MPPTaskResult + +/** The `MPPClassificationResult` instance containing one set of results per classifier head. **/ +@property(nonatomic, readonly) MPPClassificationResult *classificationResult; + +/** + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and + * timestamp (in milliseconds). + * + * @param classificationResult The `MPPClassificationResult` instance containing one set of results + * per classifier head. + * @param timestampMs The timestamp for this result. + * + * @return An instance of `MPPTextClassifierResult` initialized with the given + * `MPPClassificationResult` and timestamp (in milliseconds). + */ +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m new file mode 100644 index 000000000..4d5c1104a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +@implementation MPPTextClassifierResult + +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _classificationResult = classificationResult; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD new file mode 100644 index 000000000..6795194fb --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptionsHelpers", + srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", + ], +) + +objc_library( + name = "MPPTextClassifierResultHelpers", + srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h new file mode 100644 index 000000000..0ca393333 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -0,0 +1,27 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm new file mode 100644 index 000000000..2c9fcc07f --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
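
The `MPPTextClassifierResult` declared earlier in this change wraps one `MPPClassificationResult` holding one entry per classifier head. A hedged sketch of walking that container from Swift, using the same fields (`headIndex`, `categories`, `index`, `score`, `categoryName`) the tests rely on:

import MPPTextClassifier

// Walks a TextClassifierResult using the container fields the Swift tests
// above read (headIndex, categories, index/score/categoryName).
func logResult(_ result: TextClassifierResult) {
  for classification in result.classificationResult.classifications {
    print("head \(classification.headIndex)")
    for category in classification.categories {
      print("  [\(category.index)] \(String(describing: category.categoryName)) score=\(category.score)")
    }
  }
}
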
+ +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextClassifierGraphOptionsProto = + ::mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions; +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} // namespace + +@implementation MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextClassifierGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(TextClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options(); + classifierOptionsProto->Clear(); + + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + for (NSString *category in self.categoryAllowlist) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.categoryDenylist) { + classifierOptionsProto->add_category_denylist(category.cppString); + } +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h new file mode 100644 index 000000000..f1b728b0a --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm new file mode 100644 index 000000000..f5d6aa1d3 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -0,0 +1,42 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get<ClassificationResultProto>()]; + + return [[MPPTextClassifierResult alloc] + initWithClassificationResult:classificationResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD new file mode 100644 index 000000000..21226b012 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/BUILD @@ -0,0 +1,60 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.

+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextEmbedderOptions", + srcs = ["sources/MPPTextEmbedderOptions.m"], + hdrs = ["sources/MPPTextEmbedderOptions.h"], + deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"], +) + +objc_library( + name = "MPPTextEmbedderResult", + srcs = ["sources/MPPTextEmbedderResult.m"], + hdrs = ["sources/MPPTextEmbedderResult.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPEmbeddingResult", + "//mediapipe/tasks/ios/core:MPPTaskResult", + ], +) + +objc_library( + name = "MPPTextEmbedder", + srcs = ["sources/MPPTextEmbedder.mm"], + hdrs = ["sources/MPPTextEmbedder.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPTextEmbedder", + deps = [ + ":MPPTextEmbedderOptions", + ":MPPTextEmbedderResult", + "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers", + "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h new file mode 100644 index 000000000..b3f2ebe92 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -0,0 +1,93 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs embedding extraction on text. + * + * This API expects a TFLite model with (optional) [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata"). + * + * Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + * Input tensors: + * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires + * a Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. 
+ * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor (`kTfLiteFloat32`/`kTfLiteUint8`) with shape `[1 x N]` where `N` is + * the number of dimensions in the produced embeddings. + */ +NS_SWIFT_NAME(TextEmbedder) +@interface MPPTextEmbedder : NSObject + +/** + * Creates a new instance of `MPPTextEmbedder` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `MPPTextEmbedderOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * text embedder. + * + * @return A new instance of `MPPTextEmbedder` with the given model path. `nil` if there is an + * error in initializing the text embedder. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`. + * + * @param options The options of type `MPPTextEmbedderOptions` to use for configuring the + * `MPPTextEmbedder`. + * @param error An optional error parameter populated when there is an error in initializing the + * text embedder. + * + * @return A new instance of `MPPTextEmbedder` with the given options. `nil` if there is an + * error in initializing the text embedder. + */ +- (nullable instancetype)initWithOptions:(MPPTextEmbedderOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs embedding extraction on the input text. + * + * @param text The `NSString` on which embedding extraction is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * embedding extraction on the input text. + * + * @return A `MPPTextEmbedderResult` object that contains a list of embeddings. + */ +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(embed(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm new file mode 100644 index 000000000..a9c811cdb --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -0,0 +1,96 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
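
The embedder header above mirrors the classifier API but returns embeddings instead of categories. A hedged sketch of a typical use, comparing two strings by cosine similarity, follows; the `embeddings` and `floatEmbedding` container fields are assumptions about the `MPPEmbeddingResult` Swift surface, and no similarity helper is part of this change, so the math is done by hand.

import MPPTextEmbedder

// Illustrative only: embed two strings and compare them with a hand-rolled
// cosine similarity. Container field names (embeddings, floatEmbedding) are
// assumptions about the MPPEmbeddingResult Swift surface.
func similarity(_ a: String, _ b: String, modelPath: String) throws -> Double {
  let options = TextEmbedderOptions()
  options.baseOptions.modelAssetPath = modelPath
  options.l2Normalize = true  // only needed if the model lacks a native L2 normalization op

  let embedder = try TextEmbedder(options: options)
  let u = try embedder.embed(text: a).embeddingResult.embeddings[0].floatEmbedding ?? []
  let v = try embedder.embed(text: b).embeddingResult.embeddings[0].floatEmbedding ?? []

  let dot = zip(u, v).reduce(0.0) { $0 + $1.0.doubleValue * $1.1.doubleValue }
  let normU = u.reduce(0.0) { $0 + $1.doubleValue * $1.doubleValue }.squareRoot()
  let normV = v.reduce(0.0) { $0 + $1.doubleValue * $1.doubleValue }.squareRoot()
  return normU > 0 && normV > 0 ? dot / (normU * normV) : 0
}
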
+ +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" + +#include "absl/status/statusor.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kEmbeddingsOutStreamName = @"embeddings_out"; +static NSString *const kEmbeddingsTag = @"EMBEDDINGS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + +@interface MPPTextEmbedder () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextEmbedder + +- (instancetype)initWithOptions:(MPPTextEmbedderOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kEmbeddingsTag, + kEmbeddingsOutStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextEmbedderOptions *options = [[MPPTextEmbedderOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + + if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + return nil; + } + + return [MPPTextEmbedderResult + textEmbedderResultWithOutputPacket:statusOrOutputPacketMap + .value()[kEmbeddingsOutStreamName.cppString]]; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h new file mode 100644 index 000000000..cd059a297 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h @@ -0,0 +1,47 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options for setting up a `MPPTextEmbedder`. + */ +NS_SWIFT_NAME(TextEmbedderOptions) +@interface MPPTextEmbedderOptions : MPPTaskOptions + +/** + * @brief Sets whether L2 normalization should be performed on the returned embeddings. + * Use this option only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. + * In most cases, this is already the case and L2 norm is thus achieved through TF Lite inference. + * + * `NO` by default. + */ +@property(nonatomic) BOOL l2Normalize; + +/** + * @brief Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed to + * have value in [-1.0, 1.0]. Use the `l2Normalize` property if this is not the case. + * + * `NO` by default. + */ +@property(nonatomic) BOOL quantize; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m new file mode 100644 index 000000000..6da3659f7 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h" + +@implementation MPPTextEmbedderOptions + +- (id)copyWithZone:(NSZone *)zone { + MPPTextEmbedderOptions *textEmbedderOptions = [super copyWithZone:zone]; + + textEmbedderOptions.l2Normalize = self.l2Normalize; + textEmbedderOptions.quantize = self.quantize; + + return textEmbedderOptions; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h new file mode 100644 index 000000000..e4697dcef --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h @@ -0,0 +1,48 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Represents the embedding results generated by `MPPTextEmbedder`. **/ +NS_SWIFT_NAME(TextEmbedderResult) +@interface MPPTextEmbedderResult : MPPTaskResult + +/** The `MPPEmbedderResult` instance containing one embedding per embedder head. **/ +@property(nonatomic, readonly) MPPEmbeddingResult *embeddingResult; + +/** + * Initializes a new `MPPTextEmbedderResult` with the given `MPPEmbeddingResult` and + * timestamp (in milliseconds). + * + * @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results + * per classifier head. + * @param timestampMs The timestamp for this result. + * + * @return An instance of `MPPTextEmbedderResult` initialized with the given + * `MPPEmbeddingResult` and timestamp (in milliseconds). + */ +- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult + timestampMs:(NSInteger)timestampMs; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m new file mode 100644 index 000000000..5483e3c3f --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +@implementation MPPTextEmbedderResult + +- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; + if (self) { + _embeddingResult = embeddingResult; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/BUILD b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD new file mode 100644 index 000000000..eeb4981fb --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextEmbedderOptionsHelpers", + srcs = ["sources/MPPTextEmbedderOptions+Helpers.mm"], + hdrs = ["sources/MPPTextEmbedderOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedderOptions", + ], +) + +objc_library( + name = "MPPTextEmbedderResultHelpers", + srcs = ["sources/MPPTextEmbedderResult+Helpers.mm"], + hdrs = ["sources/MPPTextEmbedderResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPEmbeddingResultHelpers", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedderResult", + ], +) diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h new file mode 100644 index 000000000..7f3d1c958 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h @@ -0,0 +1,27 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextEmbedderOptions (Helpers) + +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm new file mode 100644 index 000000000..e17b6e8da --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.mm @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextEmbedderGraphOptionsProto = + ::mediapipe::tasks::text::text_embedder::proto::TextEmbedderGraphOptions; +using EmbedderOptionsProto = ::mediapipe::tasks::components::processors::proto::EmbedderOptions; +} // namespace + +@implementation MPPTextEmbedderOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextEmbedderGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(TextEmbedderGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + EmbedderOptionsProto *embedderOptionsProto = graphOptions->mutable_embedder_options(); + embedderOptionsProto->Clear(); + + embedderOptionsProto->set_l2_normalize(self.l2Normalize ? true : false); + embedderOptionsProto->set_quantize(self.quantize ? true : false); +} + +@end diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h new file mode 100644 index 000000000..0a808a54b --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h @@ -0,0 +1,27 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextEmbedderResult (Helpers) + ++ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket:(const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm new file mode 100644 index 000000000..b769292ce --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; + +namespace { +using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::Packet; +} // namespace + +#define int kMicroSecondsPerMilliSecond = 1000; + +@implementation MPPTextEmbedderResult (Helpers) + ++ (MPPTextEmbedderResult *)textEmbedderResultWithOutputPacket:(const Packet &)packet { + MPPEmbeddingResult *embeddingResult = + [MPPEmbeddingResult embeddingResultWithProto:packet.Get()]; + + return [[MPPTextEmbedderResult alloc] + initWithEmbeddingResult:embeddingResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; +} + +@end diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index b162d7dac..e5d472e8a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", @@ -39,6 +39,7 @@ cc_binary( deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", + "//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], ) @@ -65,10 +66,39 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "audioembedder", + srcs = [ + "audioembedder/AudioEmbedder.java", + "audioembedder/AudioEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "audioembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + 
"//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 5a82eecaa..4e5cd7655 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) @@ -265,8 +267,10 @@ public final class AudioClassifier extends BaseAudioTaskApi { } /* - * Sends audio data (a block in a continuous audio stream) to perform audio classification. Only - * use this method when the AudioClassifier is created with the audio stream mode. + * Sends audio data (a block in a continuous audio stream) to perform audio classification, and + * the results will be available via the {@link ResultListener} provided in the + * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with + * the audio stream mode. * *

The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will * be resampled, accumulated, and framed to the proper size for the underlying model to consume. @@ -318,10 +322,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -338,9 +374,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -355,6 +389,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -363,7 +404,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -371,7 +420,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -383,12 +434,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java index fcc3c6e22..258e5725b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java @@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.ClassificationsPro import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.List; -import java.util.Optional; /** Represents the classification results generated by {@link AudioClassifier}. */ @AutoValue @@ -40,8 +39,7 @@ public abstract class AudioClassifierResult implements TaskResult { for (ClassificationsProto.ClassificationResult proto : protoList) { classificationResultList.add(ClassificationResult.createFromProto(proto)); } - return new AutoValue_AudioClassifierResult( - Optional.of(classificationResultList), Optional.empty(), timestampMs); + return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs); } /** @@ -53,23 +51,22 @@ public abstract class AudioClassifierResult implements TaskResult { */ static AudioClassifierResult createFromProto( ClassificationsProto.ClassificationResult proto, long timestampMs) { - return new AutoValue_AudioClassifierResult( - Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs); + List classificationResultList = new ArrayList<>(); + classificationResultList.add(ClassificationResult.createFromProto(proto)); + return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs); } /** - * A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results - * per classifier head. The list represents the audio classification result of an audio clip, and - * is only available when running with the audio clips mode. + * A list of of timestamped {@link ClassificationResult} objects, each contains one set of results + * per classifier head. + * + *

In the "audio stream" mode, the list only contains one element, representing the + * classification result of the audio block that starts at {@link + * ClassificationResult.timestampMs} in the audio stream. Otherwise, in the "audio clips" mode, + * the list may include multiple {@link ClassificationResult} objects, each classifying an + * interval of the entire audio clip that starts at {@link ClassificationResult.timestampMs}. */ - public abstract Optional> classificationResultList(); - - /** - * Contains one set of results per classifier head. A {@link ClassificationResult} usually - * represents one audio classification result in an audio stream, and s only available when - * running with the audio stream mode. - */ - public abstract Optional classificationResult(); + public abstract List classificationResults(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml new file mode 100644 index 000000000..4cd033db8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java new file mode 100644 index 000000000..077f28ca2 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -0,0 +1,410 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.audio.audioembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.audio.audioembedder.proto.AudioEmbedderGraphOptionsProto; +import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; +import com.google.mediapipe.tasks.audio.core.RunningMode; +import com.google.mediapipe.tasks.components.containers.AudioData; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs audio embedding extraction on audio clips or audio stream. + * + *

This API expects a TFLite model with mandatory TFLite Model Metadata that contains the + * mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label + * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. + * + *

Input tensor: (kTfLiteFloat32) + * + *

    + *
  • input audio buffer of size `[batch * samples]`. + *
  • batch inference is not supported (`batch` is required to be 1). + *
  • for multi-channel models, the channels need to be interleaved. + *
+ * + *

At least one output tensor with: (kTfLiteFloat32) + * + *

    + *
  • `N` components corresponding to the `N` dimensions of the returned feature vector for this + * output layer. + *
  • Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. + *
+ */ +public final class AudioEmbedder extends BaseAudioTaskApi { + private static final String TAG = AudioEmbedder.class.getSimpleName(); + private static final String AUDIO_IN_STREAM_NAME = "audio_in"; + private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "EMBEDDINGS:embeddings_out", "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph"; + private static final long MICROSECONDS_PER_MILLISECOND = 1000; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link AudioEmbedder} instance from a model file and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link AudioEmbedder} instance from a model file and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link AudioEmbedder} instance from a model buffer and default {@link + * AudioEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. + */ + public static AudioEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link AudioEmbedder} instance from an {@link AudioEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link AudioEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation. 
+ */ + public static AudioEmbedder createFromOptions(Context context, AudioEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public AudioEmbedderResult convertToTaskResult(List packets) { + try { + if (!packets.get(EMBEDDINGS_OUT_STREAM_INDEX).isEmpty()) { + // For audio stream mode. + return AudioEmbedderResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance()), + packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp() + / MICROSECONDS_PER_MILLISECOND); + } else { + // For audio clips mode. + return AudioEmbedderResult.createFromProtoList( + PacketGetter.getProtoVector( + packets.get(TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.parser()), + -1); + } + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + if (options.resultListener().isPresent()) { + ResultListener resultListener = + new ResultListener() { + @Override + public void run(AudioEmbedderResult audioEmbedderResult, Void input) { + options.resultListener().get().run(audioEmbedderResult); + } + }; + handler.setResultListener(resultListener); + } + options.errorListener().ifPresent(handler::setErrorListener); + // Audio tasks should not drop input audio due to flow limiting, which may cause data + // inconsistency. + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(AudioEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(false) + .build(), + handler); + return new AudioEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link AudioEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe audio task {@link RunningMode}. + */ + private AudioEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME); + } + + /* + * Performs embedding extraction on the provided audio clips. Only use this method when the + * AudioEmbedder is created with the audio clips mode. + * + *

The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts + * audio clips of various lengths and audio sample rates. It's required to provide the + * corresponding audio sample rate within the {@link AudioData} object. + * + *

The input audio clip may be longer than what the model is able to process in a single + * inference. When this occurs, the input audio clip is split into multiple chunks starting at + * different timestamps. For this reason, this function returns a vector of EmbeddingResult + * objects, each associated with a timestamp corresponding to the start (in milliseconds) of the + * chunk data that was extracted. + * + * @param audioClip a MediaPipe {@link AudioData} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public AudioEmbedderResult embed(AudioData audioClip) { + return (AudioEmbedderResult) processAudioClip(audioClip); + } + + /* + * Sends audio data (a block in a continuous audio stream) to perform audio embedding, and + * the results will be available via the {@link ResultListener} provided in the + * {@link AudioClassifierOptions}. Only use this method when the AudioEmbedder is created with + * the audio stream mode. + * + *

The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will + * be resampled, accumulated, and framed to the proper size for the underlying model to consume. + * It's required to provide the corresponding audio sample rate within the {@link AudioData} object as + * well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The + * timestamps must be monotonically increasing. This method will return immediately after + * the input audio data is accepted. The results will be available in the `resultListener` + * provided in the `AudioEmbedderOptions`. The `embedAsync` method is designed to process + * audio stream data such as microphone input. + * + *

The input audio block may be longer than what the model is able to process in a single + * inference. When this occurs, the input audio block is split into multiple chunks. For this + * reason, the callback may be called multiple times (once per chunk) for each call to this + * function. + * + * @param audioBlock a MediaPipe {@link AudioData} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(AudioData audioBlock, long timestampMs) { + checkOrSetSampleRate(audioBlock.getFormat().getSampleRate()); + sendAudioStreamData(audioBlock, timestampMs); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up and {@link AudioEmbedder}. */ + @AutoValue + public abstract static class AudioEmbedderOptions extends TaskOptions { + + /** Builder for {@link AudioEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the audio embedder task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the audio embedder task. Default to the audio clips mode. + * Image embedder has two modes: + * + *

    + *
  • AUDIO_CLIPS: The mode for running audio embedding on audio clips. Users feed audio + * clips to the `embed` method, and will receive the embedding results as the return + * value. + *
  • AUDIO_STREAM: The mode for running audio embedding on the audio stream, such as from + * a microphone. Users call `embedAsync` to push the audio data into the AudioEmbedder, and + * the embedding results will be available in the result callback when the audio + * embedder finishes the work. + *
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *

False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *

False by default. + */ + public abstract Builder setQuantize(boolean quantize); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * audio embedder is in the audio stream mode. + */ + public abstract Builder setResultListener( + PureResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract AudioEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link AudioEmbedderOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the audio embedder is + * in the audio stream mode. + */ + public final AudioEmbedderOptions build() { + AudioEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.AUDIO_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The audio embedder is in the audio stream mode, a user-defined result listener" + + " must be provided in the AudioEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The audio embedder is in the audio clips mode, a user-defined result listener" + + " shouldn't be provided in AudioEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract boolean l2Normalize(); + + abstract boolean quantize(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setL2Normalize(false) + .setQuantize(false); + } + + /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java new file mode 100644 index 000000000..0cfd2297c --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
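Taken together, the new embedder mirrors the classifier API across both running modes. The following is a hedged sketch; the model path, the context/AudioData arguments, and the EmbeddingResult.embeddings() accessor are assumptions for illustration rather than part of this diff:

// Clips mode (sketch): clipA/clipB are caller-supplied AudioData; the model path is a placeholder.
AudioEmbedderOptions clipOptions =
    AudioEmbedderOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("embedder.tflite").build())
        .setRunningMode(RunningMode.AUDIO_CLIPS)
        .setL2Normalize(true)  // only needed when the model has no native L2_NORMALIZATION op
        .build();
try (AudioEmbedder embedder = AudioEmbedder.createFromOptions(context, clipOptions)) {
  AudioEmbedderResult first = embedder.embed(clipA);
  AudioEmbedderResult second = embedder.embed(clipB);
  // embeddings() on EmbeddingResult is assumed from the containers package.
  Embedding u = first.embeddingResults().get(0).embeddings().get(0);
  Embedding v = second.embeddingResults().get(0).embeddings().get(0);
  System.out.println(AudioEmbedder.cosineSimilarity(u, v));
}

// Stream mode (sketch): a result listener is mandatory, and build() enforces that.
AudioEmbedderOptions streamOptions =
    AudioEmbedderOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("embedder.tflite").build())
        .setRunningMode(RunningMode.AUDIO_STREAM)
        .setResultListener(result -> { /* consume per-chunk embeddings here */ })
        .build();
AudioEmbedder streamEmbedder = AudioEmbedder.createFromOptions(context, streamOptions);
// timestampMs must be monotonically increasing across calls.
streamEmbedder.embedAsync(audioBlock, timestampMs);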
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.audio.audioembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.List; + +/** Represents the embedding results generated by {@link AudioEmbedder}. */ +@AutoValue +public abstract class AudioEmbedderResult implements TaskResult { + + /** + * Creates an {@link AudioEmbedderResult} instance from a list of {@link + * EmbeddingsProto.EmbeddingResult} protobuf messages. + * + * @param protoList a list of {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static AudioEmbedderResult createFromProtoList( + List protoList, long timestampMs) { + List embeddingResultList = new ArrayList<>(); + for (EmbeddingsProto.EmbeddingResult proto : protoList) { + embeddingResultList.add(EmbeddingResult.createFromProto(proto)); + } + return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs); + } + + /** + * Creates an {@link AudioEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static AudioEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + List embeddingResultList = new ArrayList<>(); + embeddingResultList.add(EmbeddingResult.createFromProto(proto)); + return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs); + } + + /** + * A list of of timpstamped {@link EmbeddingResult} objects, each contains one set of results per + * embedder head. + * + *

In the "audio stream" mode, the list only contains one element, representing the embedding + * result of the audio block that starts at {@link EmbeddingResult.timestampMs} in the audio + * stream. Otherwise, in the "audio clips" mode, the list may include multiple {@link + * EmbeddingResult} objects, each contains the embedding of an interval of the entire audio clip + * that starts at {@link EmbeddingResult.timestampMs}. + */ + public abstract List embeddingResults(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index affe43559..7abde72d5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -14,6 +14,9 @@ package com.google.mediapipe.tasks.audio.core; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.tasks.components.containers.AudioData; @@ -91,7 +94,7 @@ public class BaseAudioTaskApi implements AutoCloseable { * * @param sampleRate the audio sample rate. * @throws MediaPipeException if the task is not in the audio stream mode or the provided sample - * rate is inconsisent with the previously recevied. + * rate is inconsistent with the previously received. */ protected void checkOrSetSampleRate(double sampleRate) { if (runningMode != RunningMode.AUDIO_STREAM) { @@ -116,6 +119,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. @@ -148,4 +152,71 @@ public class BaseAudioTaskApi implements AutoCloseable { public void close() { runner.close(); } + + /** + * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned + * AudioRecord instance is initialized and client needs to call {@link + * android.media.AudioRecord#startRecording} method to start recording. + * + *

Note that MediaPipe Audio tasks will up/down sample automatically to fit the sample rate + * required by the model. The default sample rate of the MediaPipe pretrained audio model, Yamnet, + * is 16kHz. + * + * @param numChannels the number of audio channels. + * @param sampleRate the audio sample rate. + * @return an {@link android.media.AudioRecord} instance in {@link + * android.media.AudioRecord#STATE_INITIALIZED} + * @throws IllegalArgumentException if the model required channel count is unsupported + * @throws IllegalStateException if AudioRecord instance failed to initialize + */ + public static AudioRecord createAudioRecord(int numChannels, int sampleRate) { + int channelConfig = 0; + switch (numChannels) { + case 1: + channelConfig = AudioFormat.CHANNEL_IN_MONO; + break; + case 2: + channelConfig = AudioFormat.CHANNEL_IN_STEREO; + break; + default: + throw new IllegalArgumentException( + "getAudioRecord method only supports 1 or 2 audio channels."); + } + + int bufferSizeInBytes = + AudioRecord.getMinBufferSize(sampleRate, channelConfig, AudioFormat.ENCODING_PCM_FLOAT); + if (bufferSizeInBytes == AudioRecord.ERROR + || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) { + throw new IllegalStateException( + String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes)); + } + AudioRecord audioRecord = + new AudioRecord( + // including MIC, UNPROCESSED, and CAMCORDER. + MediaRecorder.AudioSource.VOICE_RECOGNITION, + sampleRate, + channelConfig, + AudioFormat.ENCODING_PCM_FLOAT, + bufferSizeInBytes); + if (audioRecord.getState() != AudioRecord.STATE_INITIALIZED) { + throw new IllegalStateException(String.format("AudioRecordfailed to initialize")); + } + return audioRecord; + } + + /** + * Creates an {@link android.media.AudioRecord} instance to record audio stream that has mono + * channel at sample rate at sample rate 16kHz, the sample rate required for models like Yamnet. + * The returned AudioRecord instance is initialized and client needs to call {@link + * android.media.AudioRecord#startRecording} method to start recording. + * + * @return an {@link android.media.AudioRecord} instance in {@link + * android.media.AudioRecord#STATE_INITIALIZED} + * @throws IllegalArgumentException if the model required channel count is unsupported + * @throws IllegalStateException if AudioRecord instance failed to initialize + */ + public static AudioRecord createAudioRecord() { + // TODO: Support creating AudioRecord based on the model specifications. + return createAudioRecord(1, 16000); + } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java index f0a123810..a778eae46 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java @@ -20,7 +20,7 @@ package com.google.mediapipe.tasks.audio.core; *

    *
  • AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips. *
  • AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from - * microphone. + * a microphone. *
*/ public enum RunningMode { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..4d302b950 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -83,6 +83,15 @@ android_library( ], ) +android_library( + name = "normalized_landmark", + srcs = ["NormalizedLandmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java index e955605e4..ab3fd0bd8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -19,9 +19,9 @@ import com.google.mediapipe.formats.proto.ClassificationProto; import java.util.Objects; /** - * Category is a util class, contains a category name, its display name, a float value as score, and - * the index of the label in the corresponding label file. Typically it's used as result of - * classification or detection tasks. + * Category is a util class, that contains a category name, its display name, a float value as + * score, and the index of the label in the corresponding label file. Typically it's used as result + * of classification or detection tasks. */ @AutoValue public abstract class Category { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e45866190..7fb1b99d0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -18,16 +18,16 @@ import com.google.auto.value.AutoValue; import java.util.Objects; /** - * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the - * landmark coordinates is normalized respect to the dimension of image, and the coordinates values - * are in the range of [0,1]. Otherwise, it represenet a point in world coordinates. + * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in + * meters. z represents the landmark depth, and the smaller the value the closer the world landmark + * is to the camera. */ @AutoValue public abstract class Landmark { private static final float TOLERANCE = 1e-6f; - public static Landmark create(float x, float y, float z, boolean normalized) { - return new AutoValue_Landmark(x, y, z, normalized); + public static Landmark create(float x, float y, float z) { + return new AutoValue_Landmark(x, y, z); } // The x coordinates of the landmark. @@ -39,28 +39,24 @@ public abstract class Landmark { // The z coordinates of the landmark. public abstract float z(); - // Whether this landmark is normalized with respect to the image size. 
- public abstract boolean normalized(); - @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { return false; } Landmark other = (Landmark) o; - return other.normalized() == this.normalized() - && Math.abs(other.x() - this.x()) < TOLERANCE + return Math.abs(other.x() - this.x()) < TOLERANCE && Math.abs(other.x() - this.y()) < TOLERANCE && Math.abs(other.x() - this.z()) < TOLERANCE; } @Override public final int hashCode() { - return Objects.hash(x(), y(), z(), normalized()); + return Objects.hash(x(), y(), z()); } @Override public final String toString() { - return ""; + return ""; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java new file mode 100644 index 000000000..e77f3c3d4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -0,0 +1,63 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are + * normalized to [0.0, 1.0] by the image width and height respectively. z represents the landmark + * depth, and the smaller the value the closer the landmark is to the camera. The magnitude of z + * uses roughly the same scale as x. + */ +@AutoValue +public abstract class NormalizedLandmark { + private static final float TOLERANCE = 1e-6f; + + public static NormalizedLandmark create(float x, float y, float z) { + return new AutoValue_NormalizedLandmark(x, y, z); + } + + // The x coordinates of the normalized landmark. + public abstract float x(); + + // The y coordinates of the normalized landmark. + public abstract float y(); + + // The z coordinates of the normalized landmark. + public abstract float z(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof NormalizedLandmark)) { + return false; + } + NormalizedLandmark other = (NormalizedLandmark) o; + return Math.abs(other.x() - this.x()) < TOLERANCE + && Math.abs(other.x() - this.y()) < TOLERANCE + && Math.abs(other.x() - this.z()) < TOLERANCE; + } + + @Override + public final int hashCode() { + return Objects.hash(x(), y(), z()); + } + + @Override + public final String toString() { + return ""; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index e61e59390..b4d453935 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
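Since the normalized flag is gone, world-space and image-space points are now separate types. A tiny sketch of the updated factories (the coordinate values are arbitrary examples, not from this diff):

// World landmark: coordinates are in meters; smaller z means closer to the camera.
Landmark worldPoint = Landmark.create(0.12f, -0.05f, 0.40f);

// Normalized landmark: x and y are normalized to [0.0, 1.0] by image width and height.
NormalizedLandmark imagePoint = NormalizedLandmark.create(0.25f, 0.75f, -0.01f);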
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -29,19 +29,6 @@ android_library( ], ) -android_library( - name = "embedderoptions", - srcs = ["EmbedderOptions.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) - # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java deleted file mode 100644 index 3cd197234..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.processors; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; - -/** Embedder options shared across MediaPipe Java embedding tasks. */ -@AutoValue -public abstract class EmbedderOptions { - - /** Builder for {@link EmbedderOptions} */ - @AutoValue.Builder - public abstract static class Builder { - /** - * Sets whether L2 normalization should be performed on the returned embeddings. Use this option - * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. - * In most cases, this is already the case and L2 norm is thus achieved through TF Lite - * inference. - * - *

False by default. - */ - public abstract Builder setL2Normalize(boolean l2Normalize); - - /** - * Sets whether the returned embedding should be quantized to bytes via scalar quantization. - * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed - * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is - * not the case. - * - *

False by default. - */ - public abstract Builder setQuantize(boolean quantize); - - public abstract EmbedderOptions build(); - } - - public abstract boolean l2Normalize(); - - public abstract boolean quantize(); - - public static Builder builder() { - return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); - } - - /** - * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} - * protobuf message. - */ - public EmbedderOptionsProto.EmbedderOptions convertToProto() { - return EmbedderOptionsProto.EmbedderOptions.newBuilder() - .setL2Normalize(l2Normalize()) - .setQuantize(quantize()) - .build(); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD index b2d27bfa7..6c724106f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 01b1f653a..5f7101776 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", @@ -22,6 +22,7 @@ android_library( ], manifest = "AndroidManifest.xml", deps = [ + ":logging", "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", "//mediapipe/framework:calculator_java_proto_lite", @@ -37,11 +38,22 @@ android_library( ], ) +android_library( + name = "logging", + srcs = glob( + ["logging/*.java"], + ), + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") mediapipe_tasks_core_aar( name = "tasks_core", - srcs = glob(["*.java"]) + [ + srcs = glob(["**/*.java"]) + [ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 12f8be8ba..310f5739c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -32,6 +32,12 @@ public abstract class TaskInfo { /** Builder for {@link TaskInfo}. */ @AutoValue.Builder public abstract static class Builder { + /** Sets the MediaPipe task name. */ + public abstract Builder setTaskName(String value); + + /** Sets the MediaPipe task running mode name. */ + public abstract Builder setTaskRunningModeName(String value); + /** Sets the MediaPipe task graph name. 
*/ public abstract Builder setTaskGraphName(String value); @@ -71,6 +77,10 @@ public abstract class TaskInfo { } } + abstract String taskName(); + + abstract String taskRunningModeName(); + abstract String taskGraphName(); abstract T taskOptions(); @@ -82,7 +92,7 @@ public abstract class TaskInfo { abstract Boolean enableFlowLimiting(); public static Builder builder() { - return new AutoValue_TaskInfo.Builder(); + return new AutoValue_TaskInfo.Builder().setTaskName("").setTaskRunningModeName(""); } /* Returns a list of the output stream names without the stream tags. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java index 9bf600360..0fc48742e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java @@ -58,8 +58,8 @@ public abstract class TaskOptions { AccelerationProto.Acceleration.newBuilder(); switch (options.delegate()) { case CPU: - accelerationBuilder.setXnnpack( - InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack + accelerationBuilder.setTflite( + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite .getDefaultInstance()); break; case GPU: diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index e6fc91cf6..1a128c538 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.tasks.core.logging.TasksStatsLogger; +import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable { private final Graph graph; private final ModelResourcesCache modelResourcesCache; private final AndroidPacketCreator packetCreator; + private final TasksStatsLogger statsLogger; private long lastSeenTimestamp = Long.MIN_VALUE; private ErrorListener errorListener; @@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable { Context context, TaskInfo taskInfo, OutputHandler outputHandler) { + TasksStatsLogger statsLogger = + TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName()); AndroidAssetUtil.initializeNativeAssetManager(context); Graph mediapipeGraph = new Graph(); mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); @@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable { mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); mediapipeGraph.addMultiStreamCallback( taskInfo.outputStreamNames(), - outputHandler::run, - /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); + packets -> { + outputHandler.run(packets); + statsLogger.recordInvocationEnd(packets.get(0).getTimestamp()); + }, + /* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges()); mediapipeGraph.startRunningGraph(); // Waits until all calculators are opened and the graph is fully started. 
mediapipeGraph.waitUntilGraphIdle(); - return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler); + return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger); } /** @@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable { * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. */ public synchronized TaskResult process(Map inputs) { - addPackets(inputs, generateSyntheticTimestamp()); + long syntheticInputTimestamp = generateSyntheticTimestamp(); + // TODO: Support recording GPU input arrival. + statsLogger.recordCpuInputArrival(syntheticInputTimestamp); + addPackets(inputs, syntheticInputTimestamp); graph.waitUntilGraphIdle(); lastSeenTimestamp = outputHandler.getLatestOutputTimestamp(); return outputHandler.retrieveCachedTaskResult(); @@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized TaskResult process(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); graph.waitUntilGraphIdle(); return outputHandler.retrieveCachedTaskResult(); @@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized void send(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); } @@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); } catch (MediaPipeException e) { reportError(e); } @@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable { // Waits until all calculators are opened and the graph is fully restarted. graph.waitUntilGraphIdle(); graphStarted.set(true); + statsLogger.logSessionStart(); } catch (MediaPipeException e) { reportError(e); } @@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); if (modelResourcesCache != null) { modelResourcesCache.release(); } @@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable { private TaskRunner( Graph graph, ModelResourcesCache modelResourcesCache, - OutputHandler outputHandler) { + OutputHandler outputHandler, + TasksStatsLogger statsLogger) { this.outputHandler = outputHandler; this.graph = graph; this.modelResourcesCache = modelResourcesCache; this.packetCreator = new AndroidPacketCreator(graph); + this.statsLogger = statsLogger; graphStarted.set(true); + this.statsLogger.logSessionStart(); } /** Reports error. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java new file mode 100644 index 000000000..c10b5d224 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import android.content.Context; + +/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. */ +public class TasksStatsDummyLogger implements TasksStatsLogger { + + /** + * Creates the MediaPipe Tasks stats dummy logger. + * + * @param context a {@link Context}. + * @param taskNameStr the task api name. + * @param taskRunningModeStr the task running mode string representation. + */ + public static TasksStatsDummyLogger create( + Context context, String taskNameStr, String taskRunningModeStr) { + return new TasksStatsDummyLogger(); + } + + private TasksStatsDummyLogger() {} + + /** Logs the start of a MediaPipe Tasks API session. */ + @Override + public void logSessionStart() {} + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordCpuInputArrival(long packetTimestamp) {} + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordGpuInputArrival(long packetTimestamp) {} + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + @Override + public void recordInvocationEnd(long packetTimestamp) {} + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + @Override + public void logInvocationReport(StatsSnapshot stats) {} + + /** Logs the Tasks API session end event. */ + @Override + public void logSessionEnd() {} + + /** Logs the MediaPipe Tasks API initialization error. */ + @Override + public void logInitError() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java new file mode 100644 index 000000000..c726e7d0d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java @@ -0,0 +1,98 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import com.google.auto.value.AutoValue; + +/** The stats logger interface that defines what MediaPipe Tasks events to log. */ +public interface TasksStatsLogger { + /** Task stats snapshot. 
*/ + @AutoValue + abstract static class StatsSnapshot { + static StatsSnapshot create( + int cpuInputCount, + int gpuInputCount, + int finishedCount, + int droppedCount, + long totalLatencyMs, + long peakLatencyMs, + long elapsedTimeMs) { + return new AutoValue_TasksStatsLogger_StatsSnapshot( + cpuInputCount, + gpuInputCount, + finishedCount, + droppedCount, + totalLatencyMs, + peakLatencyMs, + elapsedTimeMs); + } + + static StatsSnapshot createDefault() { + return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0); + } + + abstract int cpuInputCount(); + + abstract int gpuInputCount(); + + abstract int finishedCount(); + + abstract int droppedCount(); + + abstract long totalLatencyMs(); + + abstract long peakLatencyMs(); + + abstract long elapsedTimeMs(); + } + + /** Logs the start of a MediaPipe Tasks API session. */ + public void logSessionStart(); + + /** + * Records MediaPipe Tasks API receiving CPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordCpuInputArrival(long packetTimestamp); + + /** + * Records MediaPipe Tasks API receiving GPU input data. + * + * @param packetTimestamp the input packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordGpuInputArrival(long packetTimestamp); + + /** + * Records the end of a Mediapipe Tasks API invocation. + * + * @param packetTimestamp the output packet timestamp that acts as the identifier of the api + * invocation. + */ + public void recordInvocationEnd(long packetTimestamp); + + /** Logs the MediaPipe Tasks API periodic invocation report. */ + public void logInvocationReport(StatsSnapshot stats); + + /** Logs the Tasks API session end event. */ + public void logSessionEnd(); + + /** Logs the MediaPipe Tasks API initialization error. */ + public void logInitError(); + + // TODO: Logs more error types. 
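As a hedged illustration only: a concrete (non-dummy) logger living in the same package could summarize a StatsSnapshot when logInvocationReport is called, using nothing but the accessors declared above. The helper below is not part of the patch.

// Hypothetical helper for a concrete TasksStatsLogger implementation.
static String summarize(TasksStatsLogger.StatsSnapshot stats) {
  int finished = stats.finishedCount();
  // Average per-invocation latency over the reporting interval.
  double avgLatencyMs = finished > 0 ? (double) stats.totalLatencyMs() / finished : 0.0;
  return String.format(
      "cpuInputs=%d gpuInputs=%d finished=%d dropped=%d avgLatencyMs=%.1f peakLatencyMs=%d elapsedTimeMs=%d",
      stats.cpuInputCount(),
      stats.gpuInputCount(),
      finished,
      stats.droppedCount(),
      avgLatencyMs,
      stats.peakLatencyMs(),
      stats.elapsedTimeMs());
}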
+} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index f0c9f81c6..727d020a6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -18,11 +18,9 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build load("@build_bazel_rules_android//android:rules.bzl", "android_library") _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", - "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -32,6 +30,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", ] _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ @@ -41,7 +40,9 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", @@ -49,6 +50,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", ] def mediapipe_tasks_core_aar(name, srcs, manifest): @@ -280,9 +282,14 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", ] + select({ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index b49169529..31cd2c89a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) # The native library of all MediaPipe text tasks. cc_binary( @@ -24,6 +24,7 @@ cc_binary( deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], ) @@ -47,12 +48,38 @@ android_library( deps = [ "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "textembedder", + srcs = [ + "textembedder/TextEmbedder.java", + "textembedder/TextEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "textembedder/AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..edb78a191 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextClassifier.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) @@ -216,20 +217,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +298,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 100644 index 000000000..d9c885d16 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java new file mode 100644 
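Before the new TextEmbedder sources below, a brief usage sketch of the reworked TextClassifierOptions; the model path and category names are placeholders, and ImageClassifierOptions later in this patch follows the same pattern.

// Hypothetical usage: the classifier knobs previously bundled in ClassifierOptions
// are now set directly on TextClassifierOptions.
TextClassifier.TextClassifierOptions options =
    TextClassifier.TextClassifierOptions.builder()
        .setBaseOptions(
            BaseOptions.builder().setModelAssetPath("text_classifier.tflite").build())
        .setMaxResults(3)         // must be > 0 when set
        .setScoreThreshold(0.5f)  // overrides any threshold from the model metadata
        .setCategoryAllowlist(Arrays.asList("positive", "negative"))
        // Also calling setCategoryDenylist(...) would make build() throw:
        // allowlist and denylist are mutually exclusive.
        .build();
TextClassifier textClassifier = TextClassifier.createFromOptions(context, options);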
index 000000000..28f351d4b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -0,0 +1,276 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Performs embedding extraction on text. + * + *
This API expects a TFLite model with (optional) TFLite Model Metadata. + * + *
Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + *
  • Input tensors + *
    • Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x + * bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input + * signature requires a Bert Tokenizer process unit in the model metadata. + *
    • Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x + * max_seq_len]} representing the input ids. This input signature requires a Regex + * Tokenizer process unit in the model metadata. + *
    • Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code + * [1]} containing the input string. + *
  • At least one output tensor ({@code kTfLiteFloat32}/{@code kTfLiteUint8}) with shape {@code + * [1 x N]} where N is the number of dimensions in the produced embeddings. + *
+ */ +public final class TextEmbedder implements AutoCloseable { + private static final String TAG = TextEmbedder.class.getSimpleName(); + private static final String TEXT_IN_STREAM_NAME = "text_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME)); + + @SuppressWarnings("ConstantCaseForConstants") + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out")); + + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + private final TaskRunner runner; + + static { + System.loadLibrary("mediapipe_tasks_text_jni"); + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates a {@link TextEmbedder} instance from a model file and the default {@link + * TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the text model with metadata in the assets. + * @throws MediaPipeException if there is is an error during {@link TextEmbedder} creation. + */ + public static TextEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link TextEmbedder} instance from a model file and the default {@link + * TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the text model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation. + */ + public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param options a {@link TextEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation. 
+ */ + public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public TextEmbedderResult convertToTaskResult(List packets) { + try { + return TextEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp()); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(TextEmbedder.class.getSimpleName()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(false) + .build(), + handler); + return new TextEmbedder(runner); + } + + /** + * Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}. + * + * @param runner a {@link TaskRunner}. + */ + private TextEmbedder(TaskRunner runner) { + this.runner = runner; + } + + /** + * Performs embedding extraction on the input text. + * + * @param inputText a {@link String} for processing. + */ + public TextEmbedderResult embed(String inputText) { + Map inputPackets = new HashMap<>(); + inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); + return (TextEmbedderResult) runner.process(inputPackets); + } + + /** Closes and cleans up the {@link TextEmbedder}. */ + @Override + public void close() { + runner.close(); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up a {@link TextEmbedder}. */ + @AutoValue + public abstract static class TextEmbedderOptions extends TaskOptions { + + /** Builder for {@link TextEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the text embedder task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
 False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
False by default. + */ + public abstract Builder setQuantize(boolean quantize); + + public abstract TextEmbedderOptions build(); + } + + abstract BaseOptions baseOptions(); + + abstract boolean l2Normalize(); + + abstract boolean quantize(); + + public static Builder builder() { + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder() + .setL2Normalize(false) + .setQuantize(false); + } + + /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java new file mode 100644 index 000000000..9d8e108ec --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link TextEmbedder}. */ +@AutoValue +public abstract class TextEmbedderResult implements TaskResult { + + /** + * Creates an {@link TextEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link TextEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. 
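A short usage sketch for the two new files above, not part of the patch: the model path is a placeholder, and it assumes the EmbeddingResult container exposes its per-head embeddings through an embeddings() list accessor.

// Hypothetical usage: embed two strings and compare them with cosine similarity.
TextEmbedder textEmbedder = TextEmbedder.createFromFile(context, "text_embedder.tflite");

TextEmbedderResult result1 = textEmbedder.embed("MediaPipe Tasks runs on-device");
TextEmbedderResult result2 = textEmbedder.embed("On-device machine learning with MediaPipe");

// Each result carries one Embedding per embedder head; compare the first heads.
Embedding e1 = result1.embeddingResult().embeddings().get(0);  // embeddings() assumed, see above
Embedding e2 = result2.embeddingResult().embeddings().get(0);
double similarity = TextEmbedder.cosineSimilarity(e1, e2);

textEmbedder.close();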
+ */ + static TextEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. */ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 4dc4a547e..0c30d7646 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", @@ -43,6 +43,8 @@ cc_binary( "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -96,12 +98,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", @@ -135,6 +136,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", @@ -145,6 +147,7 @@ android_library( android_library( name = "handlandmarker", srcs = [ + "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", ], @@ -166,6 +169,60 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "imagesegmenter", + srcs = [ + "imagesegmenter/ImageSegmenter.java", + "imagesegmenter/ImageSegmenterResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imagesegmenter/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "imageembedder", + srcs = [ + "imageembedder/ImageEmbedder.java", + "imageembedder/ImageEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 7cbedb32e..a933d2f65 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -164,7 +164,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), - packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + BaseVisionTaskApi.generateResultTimestampMs( + recognizerOptions.runningMode(), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX))); } return GestureRecognizerResult.create( PacketGetter.getProtoVector( @@ -192,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(GestureRecognizer.class.getSimpleName()) + .setTaskRunningModeName(recognizerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java index ef76bf226..90b92175d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -43,41 +42,36 @@ public abstract class GestureRecognizerResult implements TaskResult { * @param gesturesProto a List of {@link ClassificationList} */ static GestureRecognizerResult create( - List landmarksProto, - List worldLandmarksProto, + List landmarksProto, + List worldLandmarksProto, List handednessesProto, List gesturesProto, long timestampMs) { - List> multiHandLandmarks = - new ArrayList<>(); - List> multiHandWorldLandmarks = - new ArrayList<>(); + List> multiHandLandmarks = new ArrayList<>(); + List> multiHandWorldLandmarks = new ArrayList<>(); List> multiHandHandednesses = new ArrayList<>(); List> multiHandGestures = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + com.google.mediapipe.tasks.components.containers.NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { List handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -118,11 +112,10 @@ public abstract class 
GestureRecognizerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List> landmarks(); + public abstract List> landmarks(); /** Hand landmarks in world coordniates of detected hands. */ - public abstract List> - worldLandmarks(); + public abstract List> worldLandmarks(); /** Handedness of detected hands. */ public abstract List> handednesses(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java new file mode 100644 index 000000000..7b21ebddf --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.handlandmarker; + +import androidx.annotation.IntDef; + +/** The 21 hand landmarks. */ +public final class HandLandmark { + public static final int NUM_LANDMARKS = 21; + + public static final int WRIST = 0; + public static final int THUMB_CMC = 1; + public static final int THUMB_MCP = 2; + public static final int THUMB_IP = 3; + public static final int THUMB_TIP = 4; + public static final int INDEX_FINGER_MCP = 5; + public static final int INDEX_FINGER_PIP = 6; + public static final int INDEX_FINGER_DIP = 7; + public static final int INDEX_FINGER_TIP = 8; + public static final int MIDDLE_FINGER_MCP = 9; + public static final int MIDDLE_FINGER_PIP = 10; + public static final int MIDDLE_FINGER_DIP = 11; + public static final int MIDDLE_FINGER_TIP = 12; + public static final int RING_FINGER_MCP = 13; + public static final int RING_FINGER_PIP = 14; + public static final int RING_FINGER_DIP = 15; + public static final int RING_FINGER_TIP = 16; + public static final int PINKY_MCP = 17; + public static final int PINKY_PIP = 18; + public static final int PINKY_DIP = 19; + public static final int PINKY_TIP = 20; + + /** Represents a hand landmark type. 
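The HandLandmark constants above index into the 21-point landmark lists produced by the hand tasks, such as the GestureRecognizerResult.landmarks() lists shown earlier. A minimal sketch, not part of the patch, assuming the NormalizedLandmark container exposes x()/y() accessors matching its create(x, y, z) factory:

// Hypothetical usage: read the index fingertip of the first detected hand.
List<NormalizedLandmark> firstHand = result.landmarks().get(0);
NormalizedLandmark indexTip = firstHand.get(HandLandmark.INDEX_FINGER_TIP);
float x = indexTip.x();  // normalized image coordinates, assumed accessor
float y = indexTip.y();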
*/ + @IntDef({ + WRIST, + THUMB_CMC, + THUMB_MCP, + THUMB_IP, + THUMB_TIP, + INDEX_FINGER_MCP, + INDEX_FINGER_PIP, + INDEX_FINGER_DIP, + INDEX_FINGER_TIP, + MIDDLE_FINGER_MCP, + MIDDLE_FINGER_PIP, + MIDDLE_FINGER_DIP, + MIDDLE_FINGER_TIP, + RING_FINGER_MCP, + RING_FINGER_PIP, + RING_FINGER_DIP, + RING_FINGER_TIP, + PINKY_MCP, + PINKY_PIP, + PINKY_DIP, + PINKY_TIP, + }) + public @interface HandLandmarkType {} + + private HandLandmark() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index 9be489bbe..1d08ab928 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -156,7 +156,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), - packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp()); + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); } return HandLandmarkerResult.create( PacketGetter.getProtoVector( @@ -182,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(HandLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 2889b0e0b..9092c0a2d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.handlandmarker; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -32,47 +31,41 @@ import java.util.List; public abstract class HandLandmarkerResult implements TaskResult { /** - * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and - * handedness protobuf messages. + * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness + * protobuf messages. 
* * @param landmarksProto a List of {@link NormalizedLandmarkList} * @param worldLandmarksProto a List of {@link LandmarkList} * @param handednessesProto a List of {@link ClassificationList} */ static HandLandmarkerResult create( - List landmarksProto, - List worldLandmarksProto, + List landmarksProto, + List worldLandmarksProto, List handednessesProto, long timestampMs) { - List> multiHandLandmarks = - new ArrayList<>(); - List> multiHandWorldLandmarks = - new ArrayList<>(); + List> multiHandLandmarks = new ArrayList<>(); + List> multiHandWorldLandmarks = new ArrayList<>(); List> multiHandHandednesses = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { - List handWorldLandmarks = - new ArrayList<>(); + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + List handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -98,11 +91,10 @@ public abstract class HandLandmarkerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List> landmarks(); + public abstract List> landmarks(); /** Hand landmarks in world coordniates of detected hands. */ - public abstract List> - worldLandmarks(); + public abstract List> worldLandmarks(); /** Handedness of detected hands. 
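A companion sketch, not part of the patch, contrasting the two landmark spaces and the handedness lists exposed by HandLandmarkerResult; the Category.categoryName() and Landmark/NormalizedLandmark coordinate accessors are assumed from the container classes referenced above.

// Hypothetical usage: print the wrist position of every detected hand.
for (int i = 0; i < result.landmarks().size(); ++i) {
  Category handedness = result.handednesses().get(i).get(0);
  NormalizedLandmark wristInImage = result.landmarks().get(i).get(HandLandmark.WRIST);
  Landmark wristInWorld = result.worldLandmarks().get(i).get(HandLandmark.WRIST);
  // landmarks() are normalized image coordinates; worldLandmarks() are world coordinates.
  System.out.printf(
      "%s wrist: image=(%.3f, %.3f), world=(%.3f, %.3f, %.3f)%n",
      handedness.categoryName(),
      wristInImage.x(), wristInImage.y(),
      wristInWorld.x(), wristInWorld.y(), wristInWorld.z());
}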
*/ public abstract List> handednesses(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..38482797c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ImageClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) @@ -376,10 +378,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +430,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +445,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +460,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +476,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +490,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..ebdb037d6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java new file mode 100644 index 000000000..488927257 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -0,0 +1,470 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs embedding extraction on images. + * + *

The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata. + *

The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + *

    + *
  • Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + *
      + *
    • image input of size {@code [batch x height x width x channels]}. + *
    • batch inference is not supported ({@code batch} is required to be 1). + *
    • only RGB inputs are supported ({@code channels} is required to be 3). + *
    • if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the + * metadata for input normalization. + *
    + *
  • At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with shape {@code + * [1 x N]} where N is the number of dimensions in the produced embeddings. + *
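A minimal usage sketch of the ImageEmbedder API added in this file: the asset name ("embedder.tflite"), the sketch class name, and the Bitmap inputs are hypothetical placeholders, and only calls that appear in this diff are used.

// Sketch only: the asset path and Bitmap inputs are placeholders.
import android.content.Context;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder;
import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedderResult;

class ImageEmbedderSketch {
  static double similarity(Context context, Bitmap a, Bitmap b) {
    // Creates an embedder in the default IMAGE running mode.
    ImageEmbedder embedder = ImageEmbedder.createFromFile(context, "embedder.tflite");
    // Wraps each Bitmap as an MPImage and extracts one embedding result per image.
    ImageEmbedderResult resultA = embedder.embed(new BitmapImageBuilder(a).build());
    ImageEmbedderResult resultB = embedder.embed(new BitmapImageBuilder(b).build());
    // Compares the first embedding head of each result with the provided helper.
    return ImageEmbedder.cosineSimilarity(
        resultA.embeddingResult().embeddings().get(0),
        resultB.embeddingResult().embeddings().get(0));
  }
}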
+ */ +public final class ImageEmbedder extends BaseVisionTaskApi { + private static final String TAG = ImageEmbedder.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. 
+ */ + public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageEmbedderResult convertToTaskResult(List packets) { + try { + return ImageEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + BaseVisionTaskApi.generateResultTimestampMs( + options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX))); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(ImageEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs embedding extraction on the provided single image with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image) { + return embed(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided single image. Only use this method when the + * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageEmbedderResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs embedding extraction on the provided video frame with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) { + return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs embedding extraction on the provided video frame. Only use this method when the {@link + * ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform embedding extraction with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(MPImage image, long timestampMs) { + embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform embedding extraction, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method + * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up and {@link ImageEmbedder}. */ + @AutoValue + public abstract static class ImageEmbedderOptions extends TaskOptions { + + /** Builder for {@link ImageEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image embedder task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image embedder task. Default to the image mode. Image + * embedder has three modes: + * + *
    + *
  • IMAGE: The mode for performing embedding extraction on single image inputs. + *
  • VIDEO: The mode for performing embedding extraction on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for performing embedding extraction on a live stream of + * input data, such as from a camera. In this mode, {@code setResultListener} must be + * called to set up a listener to receive the embedding results asynchronously, as sketched below. + *
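The LIVE_STREAM wiring sketched here is hypothetical glue code: the asset path ("embedder.tflite") and sketch class are placeholders, and the camera pipeline that produces frames and timestamps is not shown.

// Sketch only: "embedder.tflite" is a placeholder asset path.
import android.content.Context;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder;
import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions;

class LiveStreamEmbedderSketch {
  static ImageEmbedder create(Context context) {
    ImageEmbedderOptions options =
        ImageEmbedderOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("embedder.tflite").build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            // The listener receives each ImageEmbedderResult together with the input MPImage.
            .setResultListener((result, input) -> {
              // result.embeddingResult().embeddings() holds one embedding per embedder head.
            })
            .build();
    return ImageEmbedder.createFromOptions(context, options);
  }

  static void onFrame(ImageEmbedder embedder, MPImage frame, long frameTimestampMs) {
    // Timestamps must be monotonically increasing across calls.
    embedder.embedAsync(frame, frameTimestampMs);
  }
}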
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *

False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + *

False by default. + */ + public abstract Builder setQuantize(boolean quantize); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * image embedder is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link ImageEmbedderOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image embedder is + * in the live stream mode. + */ + public final ImageEmbedderOptions build() { + ImageEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract boolean l2Normalize(); + + abstract boolean quantize(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setL2Normalize(false) + .setQuantize(false); + } + + /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java new file mode 100644 index 000000000..ee3f4abc9 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link ImageEmbedder}. */ +@AutoValue +public abstract class ImageEmbedderResult implements TaskResult { + + /** + * Creates an {@link ImageEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. */ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..6c8070364 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java new file mode 100644 index 000000000..8d07b7c68 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -0,0 +1,462 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import android.content.Context; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs image segmentation on images. + * + *

Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a + * user-defined callback function even for the synchronous API. This makes it possible for + * ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in + * the {@link ImageSegmenterOptions} for all {@link RunningMode}. + * + *
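A sketch of this callback-based contract follows; the asset path ("segmenter.tflite") and sketch class are hypothetical placeholders, and the listener body only marks where the masks arrive.

// Sketch only: "segmenter.tflite" is a placeholder asset path.
import android.content.Context;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;

class ImageSegmenterSketch {
  static void run(Context context, MPImage image) {
    ImageSegmenterOptions options =
        ImageSegmenterOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("segmenter.tflite").build())
            // Masks are delivered through this listener, even in the default IMAGE mode.
            .setResultListener((result, input) -> {
              // result.segmentations() holds the output masks without an extra copy.
            })
            .build();
    ImageSegmenter segmenter = ImageSegmenter.createFromOptions(context, options);
    // segment() returns void; the result is delivered to the listener above.
    segmenter.segment(image);
  }
}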

The API expects a TFLite model with TFLite Model Metadata. + *

    + *
  • Input image {@link MPImage} + *
      + *
    • The image that the image segmenter runs on. + *
    + *
  • Output ImageSegmenterResult {@link ImageSegmenterResult} + *
      + *
    • An ImageSegmenterResult containing segmented masks. + *
    + *
+ */ +public final class ImageSegmenter extends BaseVisionTaskApi { + private static final String TAG = ImageSegmenter.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "GROUPED_SEGMENTATION:segmented_mask_out", + "IMAGE:image_out", + "SEGMENTATION:0:segmentation")); + private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; + + /** + * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. + * + * @param context an Android {@link Context}. + * @param segmenterOptions an {@link ImageSegmenterOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageSegmenter} creation. + */ + public static ImageSegmenter createFromOptions( + Context context, ImageSegmenterOptions segmenterOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageSegmenterResult convertToTaskResult(List packets) + throws MediaPipeException { + if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + return ImageSegmenterResult.create( + new ArrayList<>(), + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + } + List segmentedMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int imageFormat = + segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK + ? MPImage.IMAGE_FORMAT_VEC32F1 + : MPImage.IMAGE_FORMAT_ALPHA; + int imageListSize = + PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + if (!PacketGetter.getImageList( + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting segmented masks. 
It usually results from incorrect" + + " options of unsupported OutputType of given model."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, imageFormat); + segmentedMasks.add(builder.build()); + } + + return ImageSegmenterResult.create( + segmentedMasks, + BaseVisionTaskApi.generateResultTimestampMs( + segmenterOptions.runningMode(), + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + handler.setResultListener(segmenterOptions.resultListener()); + segmenterOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(ImageSegmenter.class.getSimpleName()) + .setTaskRunningModeName(segmenterOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(segmenterOptions) + .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageSegmenter(runner, segmenterOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageSegmenter} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs image segmentation on the provided single image with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java + * doc for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public void segment(MPImage image) { + segment(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs image segmentation on the provided single image, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method + * when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO + * update java doc for input image format. + * + *

{@link ImageSegmenter} supports the following color space types: + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + ImageSegmenterResult unused = + (ImageSegmenterResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs image segmentation on the provided video frame with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link HandLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void segmentForVideo(MPImage image, long timestampMs) { + segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs image segmentation on the provided video frame, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method + * when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segmentForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + ImageSegmenterResult unused = + (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform hand landmarks detection with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the + * {@link ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image segmenter. The input timestamps must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void segmentAsync(MPImage image, long timestampMs) { + segmentAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform image segmentation, and the results will be available via the + * {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when + * the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image segmenter. The input timestamps must be monotonically increasing. + * + *

{@link ImageSegmenter} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void segmentAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link ImageSegmenter}. */ + @AutoValue + public abstract static class ImageSegmenterOptions extends TaskOptions { + + /** Builder for {@link ImageSegmenterOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the image segmenter task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the image segmenter task. Default to the image mode. Image + * segmenter has three modes: + * + *
    + *
  • IMAGE: The mode for performing segmentation on single image inputs. + *
  • VIDEO: The mode for performing segmentation on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for performing segmentation on a live stream of input data, such + * as from a camera. In this mode, {@code setResultListener} must be called to set up a + * listener to receive the segmentation results asynchronously, as sketched below. + *
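A sketch of a video-mode configuration with confidence-mask output; the asset path ("segmenter.tflite") and sketch class are hypothetical placeholders, and per the converter earlier in this file the confidence masks are emitted as VEC32F1 images.

// Sketch only: "segmenter.tflite" is a placeholder asset path.
import android.content.Context;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;

class ConfidenceMaskSegmenterSketch {
  static ImageSegmenter createForVideo(Context context) {
    ImageSegmenterOptions options =
        ImageSegmenterOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("segmenter.tflite").build())
            .setRunningMode(RunningMode.VIDEO)
            // One confidence mask per class, with per-pixel scores in the [0, 1] range.
            .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
            .setResultListener((result, input) -> {
              // result.segmentations() holds the per-class masks for the processed frame.
            })
            .build();
    return ImageSegmenter.createFromOptions(context, options);
  }
}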
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * The locale to use for display names specified through the TFLite Model Metadata, if any. + * Defaults to English. + */ + public abstract Builder setDisplayNamesLocale(String value); + + /** The output type from image segmenter. */ + public abstract Builder setOutputType(OutputType value); + + /** + * Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline + * is done processing an image. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional {@link ErrorListener}}. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract ImageSegmenterOptions autoBuild(); + + /** + * Validates and builds the {@link ImageSegmenterOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image segmenter is + * in the live stream mode. + */ + public final ImageSegmenterOptions build() { + ImageSegmenterOptions options = autoBuild(); + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract String displayNamesLocale(); + + abstract OutputType outputType(); + + abstract ResultListener resultListener(); + + abstract Optional errorListener(); + + /** The output type of segmentation results. */ + public enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK + } + + public static Builder builder() { + return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setDisplayNamesLocale("en") + .setOutputType(OutputType.CATEGORY_MASK) + .setResultListener((result, image) -> {}); + } + + /** + * Converts an {@link ImageSegmenterOptions} to a {@link CalculatorOptions} protobuf message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder = + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()) + .setDisplayNamesLocale(displayNamesLocale()); + + SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = + SegmenterOptionsProto.SegmenterOptions.newBuilder(); + if (outputType() == OutputType.CONFIDENCE_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); + } else if (outputType() == OutputType.CATEGORY_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); + } + // TODO: remove this once activation is handled in metadata and grpah level. 
+ segmenterOptionsBuilder.setActivation( + SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX); + taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest."); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java new file mode 100644 index 000000000..40fb93dd1 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -0,0 +1,45 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.Collections; +import java.util.List; + +/** Represents the segmentation results generated by {@link ImageSegmenter}. */ +@AutoValue +public abstract class ImageSegmenterResult implements TaskResult { + + /** + * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. + * + * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType + * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is + * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. 
+ static ImageSegmenterResult create(List segmentations, long timestampMs) { + return new AutoValue_ImageSegmenterResult( + Collections.unmodifiableList(segmentations), timestampMs); + } + + public abstract List segmentations(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 769b9137f..d706189ee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ObjectDetector.class.getSimpleName()) + .setTaskRunningModeName(detectorOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index d3f0e90f3..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -67,9 +98,7 @@ public class TextClassifierTest { ApplicationProvider.getApplicationContext(), options)); // TODO: Make MediaPipe InferenceCalculator report the detailed. // interpreter errors (e.g., "Encountered unresolved custom op"). 
- assertThat(exception) - .hasMessageThat() - .contains("interpreter_builder(&interpreter) == kTfLiteOk"); + assertThat(exception).hasMessageThat().contains("== kTfLiteOk"); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 100644 index 000000000..5d55d7cfe --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java new file mode 100644 index 000000000..48f214770 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -0,0 +1,118 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.text.textembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link TextEmbedder}/ */ +@RunWith(AndroidJUnit4.class) +public class TextEmbedderTest { + private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite"; + private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + private static final float FLOAT_DIFF_TOLERANCE = 1e-4f; + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + TextEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void embed_succeedsWithBert() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + + TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey"); + assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(20.59746f); + TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip"); + assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(21.774776f); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879); + } + + @Test + public void embed_succeedsWithRegex() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); + + TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey"); + assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(0.030935612f); + TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip"); + assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(0.0312863f); + + // Check cosine similarity. 
+ double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); + } + + @Test + public void classify_succeedsWithBertAndDifferentThemes() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + + TextEmbedderResult result0 = + textEmbedder.embed( + "When you go to this restaurant, they hold the pancake upside-down before they hand " + + "it to you. It's a great gimmick."); + TextEmbedderResult result1 = + textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'"); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index c0be4cffe..5821b36cc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -603,7 +603,7 @@ public class GestureRecognizerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate) + (Correspondence.BinaryPredicate) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java index 9e12d210f..c313d385d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; import 
com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -399,7 +399,7 @@ public class HandLandmarkerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate) + (Correspondence.BinaryPredicate) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions 
options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - 
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, 
/*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..db303a439 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java new file mode 100644 index 000000000..8dec6f80b --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -0,0 +1,441 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
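+// Instrumented tests for the ImageEmbedder Tasks API, covering options validation, running modes, and embedding similarity checks.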
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageEmbedder}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class}) +public class ImageEmbedderTest { + private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_CROP_IMAGE = "burger_crop.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageEmbedderTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedder.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void embed_succeedsWithNoOptions() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
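+      // The full burger image and its crop should produce highly similar embeddings.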
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithL2Normalization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithQuantization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776); + } + + @Test + public void embed_succeedsWithRegionOfInterest() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); + ImageEmbedderResult resultRoi = + imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
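+      // Embedding the ROI of the full image should be nearly identical to embedding the pre-cropped image.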
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoi.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f); + } + + @Test + public void embed_succeedsWithRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageEmbedderResult resultRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultRotated.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + } + + @Test + public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageEmbedderResult resultRoiRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
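+      // Similarity is lower than in the un-rotated ROI case, presumably due to resampling introduced by the rotation handling.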
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoiRotated.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageEmbedderTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(mode) + .setResultListener((result, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void embed_failsWithCallingWrongApiInImageMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + 
.setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void embed_succeedsWithImageMode() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithVideoMode() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + for (int i = 0; i < 3; ++i) { + ImageEmbedderResult result = + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + } + } + + @Test + public void embed_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void embed_succeedsWithLiveStreamMode() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + 
.setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageEmbedder.embedAsync(image, /* timestampMs= */ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndCorrectDimension( + ImageEmbedderResult result, boolean quantized) { + assertThat(result.embeddingResult().embeddings()).hasSize(1); + assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0); + assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature"); + if (quantized) { + assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024); + } else { + assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..c641d446f --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java new file mode 100644 index 000000000..c11bb1f31 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -0,0 +1,427 @@ +// Copyright 2023 The MediaPipe Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagesegmenter; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Color; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferExtractor; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageSegmenter}. 
*/ +@RunWith(Suite.class) +@SuiteClasses({ImageSegmenterTest.General.class, ImageSegmenterTest.RunningModeTest.class}) +public class ImageSegmenterTest { + private static final String DEEPLAB_MODEL_FILE = "deeplabv3.tflite"; + private static final String SELFIE_128x128_MODEL_FILE = "selfie_segm_128_128_3.tflite"; + private static final String SELFIE_144x256_MODEL_FILE = "selfie_segm_144_256_3.tflite"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final float GOLDEN_MASK_SIMILARITY = 0.96f; + private static final int MAGNIFICATION_FACTOR = 10; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageSegmenterTest { + + @Test + public void segment_successWithCategoryMask() throws Exception { + final String inputImageName = "segmentation_input_rotation0.jpg"; + final String goldenImageName = "segmentation_golden_rotation0.png"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(0); + verifyCategoryMask( + actualMaskBuffer, + expectedMaskBuffer, + GOLDEN_MASK_SIMILARITY, + MAGNIFICATION_FACTOR); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWithConfidenceMask() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(21); + // Cat category index 8. + MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWith128x128Segmentation() throws Exception { + final String inputImageName = "mozart_square.jpg"; + final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(2); + // Selfie category index 1. 
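+              // The selfie model outputs two confidence masks: background (index 0) and person (index 1).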
+ MPImage actualMaskBuffer = actualResult.segmentations().get(1); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + // TODO: enable this unit test once activation option is supported in metadata. + // @Test + // public void segment_successWith144x256Segmentation() throws Exception { + // final String inputImageName = "mozart_square.jpg"; + // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + // ImageSegmenterOptions options = + // ImageSegmenterOptions.builder() + // .setBaseOptions( + // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + // .setActivation(ImageSegmenterOptions.Activation.NONE) + // .setResultListener( + // (actualResult, inputImage) -> { + // List segmentations = actualResult.segmentations(); + // assertThat(segmentations.size()).isEqualTo(1); + // MPImage actualMaskBuffer = actualResult.segmentations().get(0); + // verifyConfidenceMask( + // actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + // }) + // .build(); + // ImageSegmenter imageSegmenter = + // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), + // options); + // imageSegmenter.segment(getImageFromAsset(inputImageName)); + // } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageSegmenterTest { + @Test + public void segment_failsWithCallingWrongApiInImageMode() throws Exception { + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageSegmenter.segmentForVideo( + getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void segment_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + 
@Test + public void segment_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((result, inputImage) -> {}) + .build(); + + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageSegmenter.segmentForVideo( + getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void segment_successWithImageMode() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.IMAGE) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(21); + // Cat category index 8. + MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + imageSegmenter.segment(getImageFromAsset(inputImageName)); + } + + @Test + public void segment_successWithVideoMode() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.VIDEO) + .setResultListener( + (actualResult, inputImage) -> { + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(21); + // Cat category index 8. 
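+              // DeepLabV3 outputs one confidence mask per PASCAL VOC class (21 in total); "cat" is class index 8.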
+ MPImage actualMaskBuffer = actualResult.segmentations().get(8); + verifyConfidenceMask( + actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + }) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i); + } + } + + @Test + public void segment_successWithLiveStreamMode() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage image = getImageFromAsset(inputImageName); + MPImage expectedResult = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (segmenterResult, inputImage) -> { + verifyConfidenceMask( + segmenterResult.segmentations().get(8), + expectedResult, + GOLDEN_MASK_SIMILARITY); + }) + .build(); + try (ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + imageSegmenter.segmentAsync(image, /* timestampsMs= */ i); + } + } + } + + @Test + public void segment_failsWithOutOfOrderInputTimestamps() throws Exception { + final String inputImageName = "cat.jpg"; + final String goldenImageName = "cat_mask.jpg"; + MPImage image = getImageFromAsset(inputImageName); + MPImage expectedResult = getImageFromAsset(goldenImageName); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (segmenterResult, inputImage) -> { + verifyConfidenceMask( + segmenterResult.segmentations().get(8), + expectedResult, + GOLDEN_MASK_SIMILARITY); + }) + .build(); + try (ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageSegmenter.segmentAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageSegmenter.segmentAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + } + + private static void verifyCategoryMask( + MPImage actualMask, MPImage goldenMask, float similarityThreshold, int magnificationFactor) { + assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth()); + assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight()); + ByteBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask); + Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask); + int consistentPixels = 0; + final int numPixels = actualMask.getWidth() * actualMask.getHeight(); + actualMaskBuffer.rewind(); + for (int y = 0; y < actualMask.getHeight(); y++) { + for (int x = 0; x < actualMask.getWidth(); x++) { + // RGB values are the same in the golden mask image. + consistentPixels += + actualMaskBuffer.get() * magnificationFactor + == Color.red(goldenMaskBitmap.getPixel(x, y)) + ? 
1 + : 0; + } + } + assertThat((float) consistentPixels / numPixels).isGreaterThan(similarityThreshold); + } + + private static void verifyConfidenceMask( + MPImage actualMask, MPImage goldenMask, float similarityThreshold) { + assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth()); + assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight()); + FloatBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask).asFloatBuffer(); + Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask); + FloatBuffer goldenMaskBuffer = getByteBufferFromBitmap(goldenMaskBitmap).asFloatBuffer(); + assertThat( + calculateSoftIOU( + actualMaskBuffer, goldenMaskBuffer, actualMask.getWidth() * actualMask.getHeight())) + .isGreaterThan((double) similarityThreshold); + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static ByteBuffer getByteBufferFromBitmap(Bitmap bitmap) { + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bitmap.getWidth() * bitmap.getHeight() * 4); + for (int y = 0; y < bitmap.getHeight(); y++) { + for (int x = 0; x < bitmap.getWidth(); x++) { + byteBuffer.putFloat((float) Color.red(bitmap.getPixel(x, y)) / 255.f); + } + } + byteBuffer.rewind(); + return byteBuffer; + } + + private static double calculateSum(FloatBuffer m) { + m.rewind(); + double sum = 0; + while (m.hasRemaining()) { + sum += m.get(); + } + m.rewind(); + return sum; + } + + private static FloatBuffer multiply(FloatBuffer m1, FloatBuffer m2, int bufferSize) { + m1.rewind(); + m2.rewind(); + FloatBuffer buffer = FloatBuffer.allocate(bufferSize); + while (m1.hasRemaining()) { + buffer.put(m1.get() * m2.get()); + } + m1.rewind(); + m2.rewind(); + buffer.rewind(); + return buffer; + } + + private static double calculateSoftIOU(FloatBuffer m1, FloatBuffer m2, int bufferSize) { + double intersectionSum = calculateSum(multiply(m1, m2, bufferSize)); + double m1m1 = calculateSum(multiply(m1, m1.duplicate(), bufferSize)); + double m2m2 = calculateSum(multiply(m2, m2.duplicate(), bufferSize)); + double unionSum = m1m1 + m2m2 - intersectionSum; + return unionSum > 0.0 ? intersectionSum / unionSum : 0.0; + } +} diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index dd8719151..6dda7a53c 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -29,11 +29,34 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + ], +) + +py_library( + name = "audio_embedder", + srcs = [ + "audio_embedder.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/audio/core:base_audio_task_api", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/audio/__init__.py b/mediapipe/tasks/python/audio/__init__.py index 947f95d9d..e129800a3 100644 --- a/mediapipe/tasks/python/audio/__init__.py +++ b/mediapipe/tasks/python/audio/__init__.py @@ -16,12 +16,18 @@ import mediapipe.tasks.python.audio.core import mediapipe.tasks.python.audio.audio_classifier +import mediapipe.tasks.python.audio.audio_embedder AudioClassifier = audio_classifier.AudioClassifier AudioClassifierOptions = audio_classifier.AudioClassifierOptions +AudioClassifierResult = audio_classifier.AudioClassifierResult +AudioEmbedder = audio_embedder.AudioEmbedder +AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions +AudioEmbedderResult = audio_embedder.AudioEmbedderResult RunningMode = core.audio_task_running_mode.AudioTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. 
del audio_classifier +del audio_embedder del core del mediapipe diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index e04e778b5..cc87d6221 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options_module.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -62,15 +62,31 @@ class AudioClassifierOptions: mode for running classification on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the classification results asynchronously. - classifier_options: Options for configuring the classifier behavior, such as - score threshold, number of results, etc. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: _ClassifierOptions = _ClassifierOptions() + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -78,7 +94,12 @@ class AudioClassifierOptions: """Generates an AudioClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _AudioClassifierGraphOptionsProto( base_options=base_options_proto, @@ -86,7 +107,30 @@ class AudioClassifierOptions: class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): - """Class that performs audio classification on audio data.""" + """Class that performs audio classification on audio data. + + This API expects a TFLite model with mandatory TFLite Model Metadata that + contains the mandatory AudioProperties of the solo input audio tensor and the + optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS per output classification tensor. + + Input tensor: + (kTfLiteFloat32) + - input audio buffer of size `[batch * samples]`. + - batch inference is not supported (`batch` is required to be 1). + - for multi-channel models, the channels must be interleaved. + At least one output tensor with: + (kTfLiteFloat32) + - `[1 x N]` array with `N` represents the number of categories. + - optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `category_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `AudioClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'AudioClassifier': @@ -257,7 +301,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): Raises: ValueError: If any of the followings: 1) The sample rate is not provided in the `AudioData` object or the - provided sample rate is inconsisent with the previously recevied. + provided sample rate is inconsistent with the previously received. 2) The current input timestamp is smaller than what the audio classifier has already processed. 
""" @@ -270,7 +314,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): elif audio_block.audio_format.sample_rate != self._default_sample_rate: raise ValueError( f'The audio sample rate provided in audio data: ' - f'{audio_block.audio_format.sample_rate} is inconsisent with ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' f'the previously received: {self._default_sample_rate}.') self._send_audio_stream_data({ diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py new file mode 100644 index 000000000..4c37783e9 --- /dev/null +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -0,0 +1,309 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe audio embedder task.""" + +import dataclasses +from typing import Callable, Mapping, List, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 +from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.audio.core import base_audio_task_api +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.components.utils import cosine_similarity +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +AudioEmbedderResult = embedding_result_module.EmbeddingResult +_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions +_RunningMode = running_mode_module.AudioTaskRunningMode +_TaskInfo = task_info_module.TaskInfo + +_AUDIO_IN_STREAM_NAME = 'audio_in' +_AUDIO_TAG = 'AUDIO' +_EMBEDDINGS_STREAM_NAME = 'embeddings_out' +_EMBEDDINGS_TAG = 'EMBEDDINGS' +_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in' +_SAMPLE_RATE_TAG = 'SAMPLE_RATE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph' +_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME = 'timestamped_embeddings_out' +_TIMESTAMPTED_EMBEDDINGS_TAG = 'TIMESTAMPED_EMBEDDINGS' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class AudioEmbedderOptions: + """Options for the audio embedder task. + + Attributes: + base_options: Base options for the audio embedder task. + running_mode: The running mode of the task. 
Default to the audio clips mode. + Audio embedder task has two running modes: 1) The audio clips mode for + running embedding extraction on independent audio clips. 2) The audio + stream mode for running embedding extraction on the audio stream, such as + from microphone. In this mode, the "result_callback" below must be + specified to receive the embedding results asynchronously. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. + result_callback: The user-defined result callback for processing audio + stream data. The result callback should only be specified when the running + mode is set to the audio stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None + result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _AudioEmbedderGraphOptionsProto: + """Generates an AudioEmbedderOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) + + return _AudioEmbedderGraphOptionsProto( + base_options=base_options_proto, + embedder_options=embedder_options_proto) + + +class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi): + """Class that performs embedding extraction on audio clips or audio stream. + + This API expects a TFLite model with mandatory TFLite Model Metadata that + contains the mandatory AudioProperties of the solo input audio tensor and the + optional (but recommended) label items as AssociatedFiles with type + TENSOR_AXIS_LABELS per output embedding tensor. + + Input tensor: + (kTfLiteFloat32) + - input audio buffer of size `[batch * samples]`. + - batch inference is not supported (`batch` is required to be 1). + - for multi-channel models, the channels must be interleaved. + At least one output tensor with: + (kTfLiteUInt8/kTfLiteFloat32) + - `N` components corresponding to the `N` dimensions of the returned + feature vector for this output layer. + - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. + """ + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'AudioEmbedder': + """Creates an `AudioEmbedder` object from a TensorFlow Lite model and the default `AudioEmbedderOptions`. + + Note that the created `AudioEmbedder` instance is in audio clips mode, for + embedding extraction on the independent audio clips. + + Args: + model_path: Path to the model. + + Returns: + `AudioEmbedder` object that's created from the model file and the + default `AudioEmbedderOptions`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. 
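Again as a reviewer aid rather than part of the patch: a sketch of the stream-mode wiring described above, with a placeholder model path; blocks of samples would then be pushed with embed_async().

```python
from mediapipe.tasks.python import audio
from mediapipe.tasks.python.core import base_options as base_options_module


def on_result(result: audio.AudioEmbedderResult, timestamp_ms: int):
  # Invoked asynchronously, once per processed chunk of the stream.
  if result.embeddings:
    print(timestamp_ms, len(result.embeddings[0].embedding))


options = audio.AudioEmbedderOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='yamnet_embedding_metadata.tflite'),  # placeholder
    running_mode=audio.RunningMode.AUDIO_STREAM,
    l2_normalize=True,
    quantize=False,
    result_callback=on_result)

with audio.AudioEmbedder.create_from_options(options) as embedder:
  # Feed blocks with monotonically increasing timestamps, e.g.:
  # embedder.embed_async(audio_block, timestamp_ms)
  pass
```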
+ """ + base_options = _BaseOptions(model_asset_path=model_path) + options = AudioEmbedderOptions( + base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: AudioEmbedderOptions) -> 'AudioEmbedder': + """Creates the `AudioEmbedder` object from audio embedder options. + + Args: + options: Options for the audio embedder task. + + Returns: + `AudioEmbedder` object that's created from `options`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from + `AudioEmbedderOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet.Packet]): + timestamp_ms = output_packets[ + _EMBEDDINGS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND + if output_packets[_EMBEDDINGS_STREAM_NAME].is_empty(): + options.result_callback( + AudioEmbedderResult(embeddings=[]), timestamp_ms) + return + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_STREAM_NAME])) + options.result_callback( + AudioEmbedderResult.create_from_pb2(embedding_result_proto), + timestamp_ms) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]), + ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME]) + ], + output_streams=[ + ':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_STREAM_NAME]), ':'.join([ + _TIMESTAMPTED_EMBEDDINGS_TAG, + _TIMESTAMPTED_EMBEDDINGS_STREAM_NAME + ]) + ], + task_options=options) + return cls( + # Audio tasks should not drop input audio due to flow limiting, which + # may cause data inconsistency. + task_info.generate_graph_config(enable_flow_limiting=False), + options.running_mode, + packets_callback if options.result_callback else None) + + def embed(self, audio_clip: _AudioData) -> List[AudioEmbedderResult]: + """Performs embedding extraction on the provided audio clips. + + The audio clip is represented as a MediaPipe AudioData. The method accepts + audio clips with various length and audio sample rate. It's required to + provide the corresponding audio sample rate within the `AudioData` object. + + The input audio clip may be longer than what the model is able to process + in a single inference. When this occurs, the input audio clip is split into + multiple chunks starting at different timestamps. For this reason, this + function returns a vector of EmbeddingResult objects, each associated + ith a timestamp corresponding to the start (in milliseconds) of the chunk + data on which embedding extraction was carried out. + + Args: + audio_clip: MediaPipe AudioData. + + Returns: + An `AudioEmbedderResult` object that contains a list of embedding result + objects, each associated with a timestamp corresponding to the start + (in milliseconds) of the chunk data on which embedding extraction was + carried out. + + Raises: + ValueError: If any of the input arguments is invalid, such as the sample + rate is not provided in the `AudioData` object. + RuntimeError: If audio embedding extraction failed to run. 
+ """ + if not audio_clip.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + output_packets = self._process_audio_clip({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_clip.buffer, transpose=True), + _SAMPLE_RATE_IN_STREAM_NAME: + packet_creator.create_double(audio_clip.audio_format.sample_rate) + }) + output_list = [] + embeddings_proto_list = packet_getter.get_proto_list( + output_packets[_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME]) + for proto in embeddings_proto_list: + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom(proto) + output_list.append( + AudioEmbedderResult.create_from_pb2(embedding_result_proto)) + return output_list + + def embed_async(self, audio_block: _AudioData, timestamp_ms: int) -> None: + """Sends audio data (a block in a continuous audio stream) to perform audio embedding extraction. + + Only use this method when the AudioEmbedder is created with the audio + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input audio data is accepted. The results will be available via the + `result_callback` provided in the `AudioEmbedderOptions`. The + `embed_async` method is designed to process auido stream data such as + microphone input. + + The input audio data may be longer than what the model is able to process + in a single inference. When this occurs, the input audio block is split + into multiple chunks. For this reason, the callback may be called multiple + times (once per chunk) for each call to this function. + + The `result_callback` provides: + - An `AudioEmbedderResult` object that contains a list of + embeddings. + - The input timestamp in milliseconds. + + Args: + audio_block: MediaPipe AudioData. + timestamp_ms: The timestamp of the input audio data in milliseconds. + + Raises: + ValueError: If any of the followings: + 1) The sample rate is not provided in the `AudioData` object or the + provided sample rate is inconsistent with the previously received. + 2) The current input timestamp is smaller than what the audio + embedder has already processed. + """ + if not audio_block.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + if not self._default_sample_rate: + self._default_sample_rate = audio_block.audio_format.sample_rate + self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME, + self._default_sample_rate) + elif audio_block.audio_format.sample_rate != self._default_sample_rate: + raise ValueError( + f'The audio sample rate provided in audio data: ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' + f'the previously received: {self._default_sample_rate}.') + + self._send_audio_stream_data({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_block.buffer, transpose=True).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + + @classmethod + def cosine_similarity(cls, u: embedding_result_module.Embedding, + v: embedding_result_module.Embedding) -> float: + """Utility function to compute cosine similarity between two embedding entries. + + May return an InvalidArgumentError if e.g. the feature vectors are + of different types (quantized vs. float), have different sizes, or have a + an L2-norm of 0. + + Args: + u: An embedding entry. + v: An embedding entry. + + Returns: + The cosine similarity for the two embeddings. + + Raises: + ValueError: May return an error if e.g. 
the feature vectors are of + different types (quantized vs. float), have different sizes, or have + an L2-norm of 0. + """ + return cosine_similarity.cosine_similarity(u, v) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 3cb9cb8e8..5b4203d7b 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index b2197c142..5b08a2b76 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -29,6 +29,7 @@ _RunningMode = running_mode_module.AudioTaskRunningMode _Timestamp = timestamp_module.Timestamp +@doc_controls.do_not_generate_docs class BaseAudioTaskApi(object): """The base class of the user-facing mediapipe audio task api classes.""" @@ -133,12 +134,10 @@ class BaseAudioTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe audio task instance on exit of the context manager. diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d931c26c7..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -68,7 +68,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/__init__.py b/mediapipe/tasks/python/components/containers/__init__.py index 65c1214af..17464db36 100644 --- a/mediapipe/tasks/python/components/containers/__init__.py +++ b/mediapipe/tasks/python/components/containers/__init__.py @@ -11,3 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
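The body of the cosine_similarity.cosine_similarity helper delegated to above is not shown in this hunk; presumably it follows the standard definition, which this self-contained sketch reproduces to make the zero-norm failure mode concrete.

```python
import numpy as np


def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
  # dot(u, v) / (||u|| * ||v||); undefined when either norm is 0, which is
  # why the task-level wrapper documents an error for zero-norm vectors.
  return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))


print(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0])))  # ~0.7071
```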
+ +"""MediaPipe Tasks Components Containers API.""" + +import mediapipe.tasks.python.components.containers.audio_data +import mediapipe.tasks.python.components.containers.bounding_box +import mediapipe.tasks.python.components.containers.category +import mediapipe.tasks.python.components.containers.classification_result +import mediapipe.tasks.python.components.containers.detections +import mediapipe.tasks.python.components.containers.embedding_result +import mediapipe.tasks.python.components.containers.landmark +import mediapipe.tasks.python.components.containers.landmark_detection_result +import mediapipe.tasks.python.components.containers.rect + +AudioDataFormat = audio_data.AudioDataFormat +AudioData = audio_data.AudioData +BoundingBox = bounding_box.BoundingBox +Category = category.Category +Classifications = classification_result.Classifications +ClassificationResult = classification_result.ClassificationResult +Detection = detections.Detection +DetectionResult = detections.DetectionResult +Embedding = embedding_result.Embedding +EmbeddingResult = embedding_result.EmbeddingResult +Landmark = landmark.Landmark +NormalizedLandmark = landmark.NormalizedLandmark +LandmarksDetectionResult = landmark_detection_result.LandmarksDetectionResult +Rect = rect.Rect +NormalizedRect = rect.NormalizedRect + +# Remove unnecessary modules to avoid duplication in API docs. +del audio_data +del bounding_box +del category +del classification_result +del detections +del embedding_result +del landmark +del landmark_detection_result +del rect +del mediapipe diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index cfdb83740..9b5419883 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,10 +16,10 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.containers.proto import category_pb2 +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_CategoryProto = category_pb2.Category +_ClassificationProto = classification_pb2.Classification @dataclasses.dataclass @@ -45,23 +45,23 @@ class Category: category_name: Optional[str] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _CategoryProto: + def to_pb2(self) -> _ClassificationProto: """Generates a Category protobuf object.""" - return _CategoryProto( + return _ClassificationProto( index=self.index, score=self.score, - display_name=self.display_name, - category_name=self.category_name) + label=self.category_name, + display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. 
diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index, diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index eef368db0..695f6df91 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -28,12 +28,3 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index 65c1214af..0eb73abe0 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""MediaPipe Tasks Components Processors API.""" + +import mediapipe.tasks.python.components.processors.classifier_options + +ClassifierOptions = classifier_options.ClassifierOptions + +# Remove unnecessary modules to avoid duplication in API docs. +del classifier_options +del mediapipe diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py deleted file mode 100644 index c86a91105..000000000 --- a/mediapipe/tasks/python/components/processors/embedder_options.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Embedder options data class.""" - -import dataclasses -from typing import Any, Optional - -from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions - - -@dataclasses.dataclass -class EmbedderOptions: - """Shared options used by all embedding extraction tasks. - - Attributes: - l2_normalize: Whether to normalize the returned feature vector with L2 norm. - Use this option only if the model does not already contain a native - L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and - L2 norm is thus achieved through TF Lite inference. - quantize: Whether the returned embedding should be quantized to bytes via - scalar quantization. Embeddings are implicitly assumed to be unit-norm and - therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use - the l2_normalize option if this is not the case. - """ - - l2_normalize: Optional[bool] = None - quantize: Optional[bool] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbedderOptionsProto: - """Generates a EmbedderOptions protobuf object.""" - return _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': - """Creates a `EmbedderOptions` object from the given protobuf object.""" - return EmbedderOptions( - l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, EmbedderOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index b64d04c72..1a18531c6 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -16,15 +16,12 @@ # Placeholder for internal Python strict library and test compatibility macro. 
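For anyone tracking the API break introduced by this deletion: the l2_normalize/quantize knobs move directly onto each embedder task's options, as the test changes further down show. A minimal before/after sketch (model path is a placeholder):

```python
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.text import text_embedder

base_options = base_options_module.BaseOptions(
    model_asset_path='text_embedder_model.tflite')  # placeholder

# Before: TextEmbedderOptions(base_options=...,
#             embedder_options=EmbedderOptions(l2_normalize=True, quantize=True))
# After, with EmbedderOptions removed:
options = text_embedder.TextEmbedderOptions(
    base_options=base_options, l2_normalize=True, quantize=True)
```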
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], - deps = [ - "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", - ], + deps = ["//mediapipe/tasks/python/components/containers:embedding_result"], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 486c02ece..ff8979458 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,10 +16,8 @@ import numpy as np from mediapipe.tasks.python.components.containers import embedding_result -from mediapipe.tasks.python.components.processors import embedder_options _Embedding = embedding_result.Embedding -_EmbedderOptions = embedder_options.EmbedderOptions def _compute_cosine_similarity(u, v): diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76e2f4f4a..6098fb5f5 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -23,14 +23,15 @@ py_library( srcs = [ "optional_dependencies.py", ], - deps = [ - "@org_tensorflow//tensorflow/tools/docs:doc_controls", - ], ) py_library( name = "base_options", srcs = ["base_options.py"], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:users", + ], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", @@ -42,6 +43,7 @@ py_library( name = "task_info", srcs = ["task_info.py"], deps = [ + ":optional_dependencies", "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_py_pb2", diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 122dc620f..b48fa2ccc 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -14,6 +14,7 @@ """Base options for MediaPipe Task APIs.""" import dataclasses +import os from typing import Any, Optional from mediapipe.tasks.cc.core.proto import base_options_pb2 @@ -49,10 +50,14 @@ class BaseOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" + if self.model_asset_path is not None: + full_path = os.path.abspath(self.model_asset_path) + else: + full_path = None + return _BaseOptionsProto( model_asset=_ExternalFileProto( - file_name=self.model_asset_path, - file_content=self.model_asset_buffer)) + file_name=full_path, file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/core/optional_dependencies.py b/mediapipe/tasks/python/core/optional_dependencies.py index d4f6a6abc..b1a0ed538 100644 --- a/mediapipe/tasks/python/core/optional_dependencies.py +++ b/mediapipe/tasks/python/core/optional_dependencies.py @@ -13,6 +13,13 @@ # limitations under the License. 
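A quick illustration of the behavioural change to BaseOptions.to_pb2() above: a relative model path is now expanded against the current working directory before it reaches the ExternalFile proto (no file access happens at this point).

```python
import os

from mediapipe.tasks.python.core import base_options as base_options_module

opts = base_options_module.BaseOptions(model_asset_path='model.tflite')
proto = opts.to_pb2()
# The ExternalFile proto now carries an absolute path.
assert proto.model_asset.file_name == os.path.abspath('model.tflite')
```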
"""MediaPipe Tasks' common but optional dependencies.""" -doc_controls = lambda: None -no_op = lambda x: x -setattr(doc_controls, 'do_not_generate_docs', no_op) +# TensorFlow isn't a dependency of mediapipe pip package. It's only +# required in the API docgen pipeline so we'll ignore it if tensorflow is not +# installed. +try: + from tensorflow.tools.docs import doc_controls +except ModuleNotFoundError: + # Replace the real doc_controls.do_not_generate_docs with an no-op + doc_controls = lambda: None + no_op = lambda x: x + setattr(doc_controls, 'do_not_generate_docs', no_op) diff --git a/mediapipe/tasks/python/core/task_info.py b/mediapipe/tasks/python/core/task_info.py index 31605480f..6ea2cee7b 100644 --- a/mediapipe/tasks/python/core/task_info.py +++ b/mediapipe/tasks/python/core/task_info.py @@ -20,8 +20,10 @@ from typing import Any, List from mediapipe.calculators.core import flow_limiter_calculator_pb2 from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +@doc_controls.do_not_generate_docs @dataclasses.dataclass class TaskInfo: """Specifications of a MediaPipe task graph. diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 10a0b9b66..6afb5a3fa 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -106,7 +106,7 @@ class MetadataPopulator(object): The metadata file (or buffer) should be generated based on the metadata schema: - third_party/tensorflow/lite/schema/metadata_schema.fbs + mediapipe/tasks/metadata/metadata_schema.fbs Example usage: Populate matadata and label file into an image classifier model. @@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf): if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: buffer_index = meta.Buffer() metadata = tflite_model.Buffers(buffer_index) + if metadata.DataLength() == 0: + continue return metadata.DataAsNumpy().tobytes() return None diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 863449126..43f1d417c 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -30,7 +30,23 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + ], +) + +py_test( + name = "audio_embedder_test", + srcs = ["audio_embedder_test.py"], + data = [ + "//mediapipe/tasks/testdata/audio:test_audio_clips", + "//mediapipe/tasks/testdata/audio:test_models", + ], + deps = [ + "//mediapipe/tasks/python/audio:audio_embedder", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 0d067e587..75146547c 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,7 
+27,6 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' @@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - score_threshold=0.9))) as classifier: + score_threshold=0.9)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['Speech']))) as classifier: + category_allowlist=['Speech'])) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase): r'exclusive options.'): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar'])) + category_allowlist=['foo'], + category_denylist=['bar']) with _AudioClassifier.create_from_options(options) as unused_classifier: pass @@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase): _AudioClassifierOptions( base_options=_BaseOptions( model_asset_path=self.two_heads_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - classifier_options=_ClassifierOptions(max_results=1), + 
max_results=1, result_callback=save_result) classifier = _AudioClassifier.create_from_options(options) audio_data_list = self._read_wav_file_as_stream(audio_file) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py new file mode 100644 index 000000000..f280235d7 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -0,0 +1,314 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio embedder.""" +import enum +import os +from typing import List, Tuple +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +from scipy.io import wavfile + +from mediapipe.tasks.python.audio import audio_embedder +from mediapipe.tasks.python.audio.core import audio_task_running_mode +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils + +_AudioEmbedder = audio_embedder.AudioEmbedder +_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions +_AudioEmbedderResult = audio_embedder.AudioEmbedderResult +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode + +_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' +_YAMNET_MODEL_SAMPLE_RATE = 16000 +_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav' +_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav' +_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio' +_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384] +_YAMNET_NUM_OF_SAMPLES = 15600 +_MILLSECONDS_PER_SECOND = 1000 +# Tolerance for embedding vector coordinate values. +_EPSILON = 3e-6 +# Tolerance for cosine similarity evaluation. 
+_SIMILARITY_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class AudioEmbedderTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.yamnet_model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE)) + + def _read_wav_file(self, file_name) -> _AudioData: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + return _AudioData.create_from_array( + buffer.astype(float) / np.iinfo(np.int16).max, sample_rate) + + def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + audio_data_list = [] + start = 0 + step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE + while start < len(buffer): + end = min(start + (int)(step_size), len(buffer)) + audio_data_list.append((_AudioData.create_from_array( + buffer[start:end].astype(float) / np.iinfo(np.int16).max, + sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND))) + start = end + return audio_data_list + + def _check_embedding_value(self, result, expected_first_value): + # Check embedding first value. + self.assertAlmostEqual( + result.embeddings[0].embedding[0], expected_first_value, delta=_EPSILON) + + def _check_embedding_size(self, result, quantize, expected_embedding_size): + # Check embedding size. + self.assertLen(result.embeddings, 1) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, expected_embedding_size) + if quantize: + self.assertEqual(embedding_result.embedding.dtype, np.uint8) + else: + self.assertEqual(embedding_result.embedding.dtype, float) + + def _check_cosine_similarity(self, result0, result1, expected_similarity): + # Checks cosine similarity. + similarity = _AudioEmbedder.cosine_similarity(result0.embeddings[0], + result1.embeddings[0]) + self.assertAlmostEqual( + similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) + + def _check_yamnet_result(self, + embedding_result0_list: List[_AudioEmbedderResult], + embedding_result1_list: List[_AudioEmbedderResult], + expected_similarities: List[float]): + expected_size = len(expected_similarities) + self.assertLen(embedding_result0_list, expected_size) + self.assertLen(embedding_result1_list, expected_size) + + for idx in range(expected_size): + embedding_result0 = embedding_result0_list[idx] + embedding_result1 = embedding_result1_list[idx] + self._check_cosine_similarity(embedding_result0, embedding_result1, + expected_similarities[idx]) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. 
+ with _AudioEmbedder.create_from_options( + _AudioEmbedderOptions( + base_options=_BaseOptions( + model_asset_path=self.yamnet_model_path))) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _AudioEmbedderOptions(base_options=base_options) + _AudioEmbedder.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.yamnet_model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _AudioEmbedderOptions(base_options=base_options) + embedder = _AudioEmbedder.create_from_options(options) + self.assertIsInstance(embedder, _AudioEmbedder) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, ModelFileType.FILE_NAME, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0)), + (False, False, ModelFileType.FILE_CONTENT, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0))) + def test_embed_with_yamnet_model(self, l2_normalize, quantize, + model_file_type, audio_file0, audio_file1, + expected_size, expected_first_values): + # Creates embedder. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.yamnet_model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.yamnet_model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _AudioEmbedderOptions( + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) + + with _AudioEmbedder.create_from_options(options) as embedder: + embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) + embedding_result1_list = embedder.embed(self._read_wav_file(audio_file1)) + + # Checks embeddings and cosine similarity. 
+ expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(embedding_result0_list[0], quantize, + expected_size) + self._check_embedding_size(embedding_result1_list[0], quantize, + expected_size) + self._check_embedding_value(embedding_result0_list[0], + expected_result0_value) + self._check_embedding_value(embedding_result1_list[0], + expected_result1_value) + self._check_yamnet_result( + embedding_result0_list, + embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + def test_embed_with_yamnet_model_and_different_inputs(self): + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + embedding_result0_list = embedder.embed( + self._read_wav_file(_SPEECH_WAV_16K_MONO)) + embedding_result1_list = embedder.embed( + self._read_wav_file(_TWO_HEADS_WAV_16K_MONO)) + self.assertLen(embedding_result0_list, 5) + self.assertLen(embedding_result1_list, 1) + self._check_cosine_similarity( + embedding_result0_list[0], + embedding_result1_list[0], + expected_similarity=0.09017) + + def test_missing_sample_rate_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with self.assertRaisesRegex(ValueError, + r'Must provide the audio sample rate'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_sample_rate_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'provide the audio sample rate in audio data'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def test_illegal_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def test_calling_embed_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex(ValueError, + r'not initialized with the audio clips mode'): + embedder.embed(self._read_wav_file(_SPEECH_WAV_16K_MONO)) + + def test_calling_embed_async_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex( + ValueError, r'not initialized with the audio stream mode'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 
0) + + def test_embed_async_calls_with_illegal_timestamp(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, _SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO)) + def test_embed_async(self, l2_normalize, quantize, audio_file0, audio_file1): + embedding_result_list = [] + embedding_result_list_copy = embedding_result_list.copy() + + def save_result(result: _AudioEmbedderResult, timestamp_ms: int): + result.timestamp_ms = timestamp_ms + embedding_result_list.append(result) + + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + l2_normalize=l2_normalize, + quantize=quantize, + result_callback=save_result) + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data0_list = self._read_wav_file_as_stream(audio_file0) + for audio_data, timestamp_ms in audio_data0_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result0_list = embedding_result_list + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data1_list = self._read_wav_file_as_stream(audio_file1) + embedding_result_list = embedding_result_list_copy + for audio_data, timestamp_ms in audio_data1_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result1_list = embedding_result_list + + self._check_yamnet_result( + embedding_result0_list, + embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py index bed9c2833..d892f1b61 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest): ("The number of output tensors (1) should match the number of " "output tensor metadata (0)"), str(error.exception)) - def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + def testLoadMetadataAndAssociatedFilesShouldSucceed(self): # Create a src model with metadata and two associated files. src_model_buf = self._create_model_buf() populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) @@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest): populator_src.get_model_buffer()) populator_dst.populate() - # Tests if the metadata and associated files are populated correctly. + # Test if the metadata and associated files are populated correctly. 
dst_model_file = self.create_tempfile().full_path with open(dst_model_file, "wb") as f: f.write(populator_dst.get_model_buffer()) @@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest): recorded_files = populator_dst.get_recorded_associated_file_list() self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self): + # When the user hasn't specified the metadata, but only the associated + # files, an empty metadata buffer is created. Previously, it caused an + # exception when reading. + + # Create a source model with two associated files but no metadata. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the files from `src_model_buf`. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Test if the metadata and associated files are populated correctly. + packed_files = populator_dst.get_packed_associated_file_list() + self.assertEqual(set(packed_files), set(self.expected_recorded_files)) + @parameterized.named_parameters( { "testcase_name": "InputTensorWithBert", diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 38e56bdb2..0e2b06012 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -28,7 +28,6 @@ py_test( deps = [ "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_classifier", @@ -44,7 +43,6 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_embedder", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 8678d2194..8df7dce86 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -21,14 +21,12 @@ from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py 
b/mediapipe/tasks/python/test/text/text_embedder_test.py index c9090026c..455deba03 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -21,13 +21,11 @@ from absl.testing import parameterized import numpy as np from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_embedder _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions @@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _TextEmbedder.create_from_options(options) # Extracts both embeddings. @@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _TextEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. positive_text0 = "it's a charming and often affecting journey" @@ -198,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase): self._check_embedding_value(result1, expected_result1_value) self._check_cosine_similarity(result0, result1, expected_similarity) + def test_embed_with_mobile_bert_and_different_themes(self): + # Creates embedder. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE) + ) + base_options = _BaseOptions(model_asset_path=model_path) + options = _TextEmbedderOptions(base_options=base_options) + embedder = _TextEmbedder.create_from_options(options) + + # Extracts both embeddings. + text0 = ( + 'When you go to this restaurant, they hold the pancake upside-down ' + "before they hand it to you. It's a great gimmick." + ) + result0 = embedder.embed(text0) + + text1 = "Let's make a plan to steal the declaration of independence." + result1 = embedder.embed(text1) + + similarity = _TextEmbedder.cosine_similarity( + result0.embeddings[0], result1.embeddings[0] + ) + + # TODO: The similarity should likely be lower + self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE) + + # Closes the embedder explicitly when the embedder is not used in + # a context. 
+ embedder.close() + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 066107421..48ecc30b3 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,7 +49,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", @@ -69,7 +68,6 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_embedder", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 77f16278f..b47efb32b 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier @@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _Classifications = classification_result_module.Classifications _Image = image.Image @@ -63,7 +61,7 @@ def _generate_empty_results() -> ImageClassifierResult: timestamp_ms=0) -def _generate_burger_results() -> ImageClassifierResult: +def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( classifications=[ _Classifications( @@ -72,30 +70,36 @@ def _generate_burger_results() -> ImageClassifierResult: index=934, score=0.793959, display_name='', - category_name='cheeseburger'), + category_name='cheeseburger', + ), _Category( index=932, score=0.0273929, display_name='', - category_name='bagel'), + category_name='bagel', + ), _Category( index=925, score=0.0193408, display_name='', - category_name='guacamole'), + category_name='guacamole', + ), _Category( index=963, score=0.00632786, display_name='', - category_name='meat loaf') + category_name='meat loaf', + ), ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) -def _generate_soccer_ball_results() -> ImageClassifierResult: +def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult: return ImageClassifierResult( 
classifications=[ _Classifications( @@ -104,12 +108,15 @@ def _generate_soccer_ball_results() -> ImageClassifierResult: index=806, score=0.996527, display_name='', - category_name='soccer ball') + category_name='soccer ball', + ) ], head_index=0, - head_name='probability') + head_name='probability', + ) ], - timestamp_ms=0) + timestamp_ms=timestamp_ms, + ) class ModelFileType(enum.Enum): @@ -171,9 +178,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. @@ -200,9 +206,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -212,9 +217,7 @@ class ImageClassifierTest(parameterized.TestCase): def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) - custom_classifier_options = _ClassifierOptions(max_results=1) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -230,11 +233,9 @@ class ImageClassifierTest(parameterized.TestCase): _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -249,11 +250,9 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -263,11 +262,9 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): - custom_classifier_options = _ClassifierOptions( - category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=_ALLOW_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -280,10 +277,9 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): - custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_denylist=_DENY_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -301,19 +297,17 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): - custom_classifier_options = _ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=['foo'], + category_denylist=['bar']) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): - custom_classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=1) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -386,24 +380,25 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=4) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( self.test_image, timestamp) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_burger_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + _generate_burger_results(timestamp).to_pb2(), + ) def test_classify_for_video_succeeds_with_region_of_interest(self): - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. 
test_image = _Image.create_from_file( @@ -415,8 +410,11 @@ class ImageClassifierTest(parameterized.TestCase): for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( test_image, timestamp, image_processing_options) - test_utils.assert_proto_equals(self, classification_result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, + classification_result.to_pb2(), + _generate_soccer_ball_results(timestamp).to_pb2(), + ) def test_calling_classify_in_live_stream_mode(self): options = _ImageClassifierOptions( @@ -439,11 +437,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, result_callback=mock.MagicMock()) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) @@ -466,16 +463,14 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions( - max_results=4, score_threshold=threshold) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, + score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(self.test_image, timestamp) + classifier.classify_async(self.test_image, 0) def test_classify_async_succeeds_with_region_of_interest(self): # Load the test image. 
@@ -489,23 +484,21 @@ class ImageClassifierTest(parameterized.TestCase): def check_result(result: ImageClassifierResult, output_image: _Image, timestamp_ms: int): - test_utils.assert_proto_equals(self, result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2() + ) self.assertEqual(output_image.width, test_image.width) self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: - for timestamp in range(0, 300, 30): - classifier.classify_async(test_image, timestamp, - image_processing_options) + classifier.classify_async(test_image, 100, image_processing_options) if __name__ == '__main__': diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4bb96bad6..11c0cf002 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -24,7 +24,6 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_embedder @@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _ImageEmbedder.create_from_options(options) image_processing_options = None @@ -186,10 +182,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _ImageEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. 
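The test and library changes above track an API flattening: the nested `ClassifierOptions` / `EmbedderOptions` dataclasses are removed and their fields (`display_names_locale`, `max_results`, `score_threshold`, `category_allowlist`, `category_denylist`, `l2_normalize`, `quantize`) now live directly on each task's options dataclass, which assembles the corresponding proto in `to_pb2()`. A minimal sketch of the new construction follows; the model paths are hypothetical placeholders, not files shipped with this change.

```
# Sketch of the flattened options API exercised by the updated tests.
# 'embedder.tflite' and 'classifier.tflite' are hypothetical local paths.
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.text import text_embedder
from mediapipe.tasks.python.vision import image_classifier

_BaseOptions = base_options_module.BaseOptions

# Embedder options: l2_normalize/quantize are now top-level fields instead of
# a nested embedder_options object.
embedder_options = text_embedder.TextEmbedderOptions(
    base_options=_BaseOptions(model_asset_path='embedder.tflite'),
    l2_normalize=True,
    quantize=False)

with text_embedder.TextEmbedder.create_from_options(embedder_options) as embedder:
  result = embedder.embed("it's a charming and often affecting journey")

# Classifier options: max_results/score_threshold replace classifier_options.
classifier_options = image_classifier.ImageClassifierOptions(
    base_options=_BaseOptions(model_asset_path='classifier.tflite'),
    max_results=4,
    score_threshold=0.5)
```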
diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index bb42da912..9d5d23261 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -23,13 +23,14 @@ py_library( srcs = [ "text_classifier.py", ], + visibility = ["//mediapipe/tasks:users"], deps = [ "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -46,9 +47,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py index e2473f56b..ecf3a0ad2 100644 --- a/mediapipe/tasks/python/text/__init__.py +++ b/mediapipe/tasks/python/text/__init__.py @@ -15,10 +15,16 @@ """MediaPipe Tasks Text API.""" import mediapipe.tasks.python.text.text_classifier +import mediapipe.tasks.python.text.text_embedder TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier.TextClassifierOptions +TextClassifierResult = text_classifier.TextClassifierResult +TextEmbedder = text_embedder.TextEmbedder +TextEmbedderOptions = text_embedder.TextEmbedderOptions +TextEmbedderResult = text_embedder.TextEmbedderResult # Remove unnecessary modules to avoid duplication in API docs. del mediapipe del text_classifier +del text_embedder diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD index 072a0c7d8..e76bd4b6d 100644 --- a/mediapipe/tasks/python/text/core/BUILD +++ b/mediapipe/tasks/python/text/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/core/base_text_task_api.py b/mediapipe/tasks/python/text/core/base_text_task_api.py index b22bfff00..1d6311561 100644 --- a/mediapipe/tasks/python/text/core/base_text_task_api.py +++ b/mediapipe/tasks/python/text/core/base_text_task_api.py @@ -20,6 +20,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls _TaskRunner = task_runner.TaskRunner +@doc_controls.do_not_generate_docs class BaseTextTaskApi(object): """The base class of the user-facing mediapipe text task api classes.""" @@ -40,12 +41,10 @@ class BaseTextTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Returns `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe text task instance on exit of the context manager. diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 00f35fada..fdb20f0ef 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,13 +14,14 @@ """MediaPipe text classifier task.""" import dataclasses +from typing import Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -29,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _TaskInfo = task_info_module.TaskInfo _CLASSIFICATIONS_STREAM_NAME = 'classifications_out' @@ -45,16 +46,38 @@ class TextClassifierOptions: Attributes: base_options: Base options for the text classifier task. - classifier_options: Options for the text classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. 
If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. """ base_options: _BaseOptions - classifier_options: _ClassifierOptions = _ClassifierOptions() + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: """Generates an TextClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _TextClassifierGraphOptionsProto( base_options=base_options_proto, @@ -62,7 +85,38 @@ class TextClassifierOptions: class TextClassifier(base_text_task_api.BaseTextTaskApi): - """Class that performs classification on text.""" + """Class that performs classification on text. + + This API expects a TFLite model with (optional) TFLite Model Metadata that + contains the mandatory (described below) input tensors, output tensor, + and the optional (but recommended) category labels as AssociatedFiles with + type + TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for + models with int32 input tensors because it contains the input process unit + for the model's Tokenizer. No metadata is required for models with string + input tensors. + + Input tensors: + (kTfLiteInt32) + - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing + the input ids, segment ids, and mask ids + - or 1 input tensor of size `[batch_size x max_seq_len]` representing the + input ids + or (kTfLiteString) + - 1 input tensor that is shapeless or has shape [1] containing the input + string + At least one output tensor with: + (kTfLiteFloat32/kBool) + - `[1 x N]` array with `N` represents the number of categories. + - optional (but recommended) category labels as AssociatedFiles with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `category_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `TextClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. 
+ """ @classmethod def create_from_model_path(cls, model_path: str) -> 'TextClassifier': diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index 2395f6d6b..be899636d 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,13 +14,14 @@ """MediaPipe text embedder task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -30,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _TaskInfo = task_info_module.TaskInfo _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' @@ -46,16 +47,25 @@ class TextEmbedderOptions: Attributes: base_options: Base options for the text embedder task. - embedder_options: Options for the text embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. """ base_options: _BaseOptions - embedder_options: _EmbedderOptions = _EmbedderOptions() + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: """Generates an TextEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _TextEmbedderGraphOptionsProto( base_options=base_options_proto, @@ -63,7 +73,27 @@ class TextEmbedderOptions: class TextEmbedder(base_text_task_api.BaseTextTaskApi): - """Class that performs embedding extraction on text.""" + """Class that performs embedding extraction on text. + + This API expects a TFLite model with TFLite Model Metadata that contains the + mandatory (described below) input tensors and output tensors. Metadata should + contain the input process unit for the model's Tokenizer as well as input / + output tensor metadata. 
+ + Input tensors: + (kTfLiteInt32) + - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names + "ids", "mask", and "segment_ids" representing the input ids, mask ids, and + segment ids respectively. + - or 1 input tensor of size `[batch_size x max_seq_len]` representing the + input ids. + + At least one output tensor with: + (kTfLiteFloat32) + - `N` components corresponding to the `N` dimensions of the returned + feature vector for this output layer. + - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'TextEmbedder': diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e94507eed..eda8e290d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -47,10 +47,10 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -69,8 +69,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -89,9 +89,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", @@ -131,6 +131,10 @@ py_library( srcs = [ "hand_landmarker.py", ], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2", diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index e2b2b3dec..18df690a0 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility 
macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 016170398..0c8262d4b 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -31,6 +31,7 @@ _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +@doc_controls.do_not_generate_docs class BaseVisionTaskApi(object): """The base class of the user-facing mediapipe vision task api classes.""" @@ -178,12 +179,10 @@ class BaseVisionTaskApi(object): """ self._runner.close() - @doc_controls.do_not_generate_docs def __enter__(self): """Return `self` upon entering the runtime context.""" return self - @doc_controls.do_not_generate_docs def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): """Shuts down the mediapipe vision task instance on exit of the context manager. diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9b6fd8cab..227203a0d 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -181,9 +181,11 @@ class GestureRecognizerOptions: min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) result_callback: Optional[Callable[ [GestureRecognizerResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 3367f1da7..a0cd99a83 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -14,6 +14,7 @@ """MediaPipe hand landmarker task.""" import dataclasses +import enum from typing import Callable, Mapping, Optional, List from mediapipe.framework.formats import classification_pb2 @@ -53,6 +54,31 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. 
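The `hand_landmarker.py` hunk above adds a `HandLandmark` IntEnum naming the 21 hand landmarks, so downstream code can address landmarks by name instead of magic indices. A minimal usage sketch, assuming the result object exposes a per-hand `hand_landmarks` list (that field belongs to the existing `HandLandmarkerResult` API and is not shown in this hunk):

```
# Sketch: addressing landmarks through the new HandLandmark enum.
from mediapipe.tasks.python.vision import hand_landmarker

HandLandmark = hand_landmarker.HandLandmark

def fingertip_of_first_hand(result):
  # `result` is assumed to be a HandLandmarkerResult whose `hand_landmarks`
  # field holds one list of 21 landmarks per detected hand.
  first_hand = result.hand_landmarks[0]
  # IntEnum members work directly as list indices (INDEX_FINGER_TIP == 8).
  return first_hand[HandLandmark.INDEX_FINGER_TIP]
```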
diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 0537e7dbb..b60d18e31 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,17 +14,17 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping, Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -63,14 +63,31 @@ class ImageClassifierOptions: objects on single image inputs. 2) The video mode for classifying objects on the decoded frames of a video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - classifier_options: Options for the image classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: _ClassifierOptions = _ClassifierOptions() + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None @@ -79,7 +96,12 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, @@ -87,7 +109,40 @@ class ImageClassifierOptions: class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): - """Class that performs image classification on images.""" + """Class that performs image classification on images. + + The API expects a TFLite model with optional, but strongly recommended, + TFLite Model Metadata. + + Input tensor: + (kTfLiteUInt8/kTfLiteFloat32) + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is kTfLiteFloat32, NormalizationOptions are required to be + attached to the metadata for input normalization. + At least one output tensor with: + (kTfLiteUInt8/kTfLiteFloat32) + - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or + `[1 x 1 x 1 x N]` + - optional (but recommended) label map(s) as AssociatedFiles with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `class_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `ImageClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. + - optional score calibration can be attached using ScoreCalibrationOptions + and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See + metadata_schema.fbs [1] for more details. 
+ + An example of such model can be found at: + https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1 + + [1]: + https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'ImageClassifier': diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index 922040397..0bae21bda 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -62,14 +62,22 @@ class ImageEmbedderOptions: image on single image inputs. 2) The video mode for embedding image on the decoded frames of a video. 3) The live stream mode for embedding image on a live stream of input data, such as from camera. - embedder_options: Options for the image embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: _EmbedderOptions = _EmbedderOptions() + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None @@ -78,7 +86,8 @@ class ImageEmbedderOptions: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _ImageEmbedderGraphOptionsProto( base_options=base_options_proto, @@ -86,7 +95,24 @@ class ImageEmbedderOptions: class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): - """Class that performs embedding extraction on images.""" + """Class that performs embedding extraction on images. + + The API expects a TFLite model with optional, but strongly recommended, + TFLite Model Metadata. + + Input tensor: + (kTfLiteUInt8/kTfLiteFloat32) + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is kTfLiteFloat32, NormalizationOptions are required to be + attached to the metadata for input normalization. + At least one output tensor with: + (kTfLiteUInt8/kTfLiteFloat32) + - `N` components corresponding to the `N` dimensions of the returned + feature vector for this output layer. + - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder': diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 1740d41ef..22a37cb3e 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -93,7 +93,29 @@ class ImageSegmenterOptions: class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): - """Class that performs image segmentation on images.""" + """Class that performs image segmentation on images. + + The API expects a TFLite model with mandatory TFLite Model Metadata. + + Input tensor: + (kTfLiteUInt8/kTfLiteFloat32) + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - RGB and greyscale inputs are supported (`channels` is required to be + 1 or 3). + - if type is kTfLiteFloat32, NormalizationOptions are required to be + attached to the metadata for input normalization. + Output tensors: + (kTfLiteUInt8/kTfLiteFloat32) + - list of segmented masks. 
+ - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. + - if `output_type` is CONFIDENCE_MASK, float32 Image list of size + `channels`. + - batch is always 1 + + An example of such model can be found at: + https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': diff --git a/mediapipe/tasks/python/vision/object_detector.py b/mediapipe/tasks/python/vision/object_detector.py index f6177cda2..7c9993d62 100644 --- a/mediapipe/tasks/python/vision/object_detector.py +++ b/mediapipe/tasks/python/vision/object_detector.py @@ -98,7 +98,49 @@ class ObjectDetectorOptions: class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): - """Class that performs object detection on images.""" + """Class that performs object detection on images. + + The API expects a TFLite model with mandatory TFLite Model Metadata. + + Input tensor: + (kTfLiteUInt8/kTfLiteFloat32) + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is kTfLiteFloat32, NormalizationOptions are required to be + attached to the metadata for input normalization. + Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e: + (kTfLiteFloat32) + - locations tensor of size `[num_results x 4]`, the inner array + representing bounding boxes in the form [top, left, right, bottom]. + - BoundingBoxProperties are required to be attached to the metadata + and must specify type=BOUNDARIES and coordinate_type=RATIO. + (kTfLiteFloat32) + - classes tensor of size `[num_results]`, each value representing the + integer index of a class. + - optional (but recommended) label map(s) can be attached as + AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per + line. The first such AssociatedFile (if any) is used to fill the + `class_name` field of the results. The `display_name` field is filled + from the AssociatedFile (if any) whose locale matches the + `display_names_locale` field of the `ObjectDetectorOptions` used at + creation time ("en" by default, i.e. English). If none of these are + available, only the `index` field of the results will be filled. + (kTfLiteFloat32) + - scores tensor of size `[num_results]`, each value representing the score + of the detected object. + - optional score calibration can be attached using ScoreCalibrationOptions + and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See + metadata_schema.fbs [1] for more details. 
+ (kTfLiteFloat32) + - integer num_results as a tensor of size `[1]` + + An example of such model can be found at: + https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1 + + [1]: + https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + """ @classmethod def create_from_model_path(cls, model_path: str) -> 'ObjectDetector': diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 081e63c2c..a0131c056 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -18,7 +18,10 @@ load( ) package( - default_visibility = ["//mediapipe/framework:mediapipe_internal"], + default_visibility = [ + "//mediapipe/calculators/tensor:__subpackages__", + "//mediapipe/tasks:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ad8072b87..09f830aba 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -37,9 +37,13 @@ mediapipe_files(srcs = [ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "face_detection_full_range.tflite", + "face_detection_full_range_sparse.tflite", "fist.jpg", + "fist.png", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", + "hand_landmarker.task", "left_hands.jpg", "left_hands_rotated.jpg", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", @@ -56,6 +60,7 @@ mediapipe_files(srcs = [ "palm_detection_full.tflite", "pointing_up.jpg", "pointing_up_rotated.jpg", + "portrait.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -77,6 +82,7 @@ exports_files( "expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", "gesture_recognizer.task", + "portrait_expected_detection.pbtxt", ], ) @@ -94,6 +100,7 @@ filegroup( "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "fist.jpg", + "fist.png", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "left_hands.jpg", @@ -103,6 +110,7 @@ filegroup( "multi_objects_rotated.jpg", "pointing_up.jpg", "pointing_up_rotated.jpg", + "portrait.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -126,6 +134,8 @@ filegroup( "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "face_detection_full_range.tflite", + "face_detection_full_range_sparse.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "hand_landmarker.task", @@ -158,6 +168,7 @@ filegroup( "hand_detector_result_two_hands.pbtxt", "pointing_up_landmarks.pbtxt", "pointing_up_rotated_landmarks.pbtxt", + "portrait_expected_detection.pbtxt", "thumb_up_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt", "victory_landmarks.pbtxt", diff --git a/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt new file mode 100644 index 000000000..775f4479b --- /dev/null +++ b/mediapipe/tasks/testdata/vision/portrait_expected_detection.pbtxt @@ -0,0 +1,35 @@ +# proto-file: mediapipe/framework/formats/detection.proto +# proto-message: Detection +location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.35494408 + ymin: 0.1059662 + 
width: 0.28768203 + height: 0.23037356 + } + relative_keypoints { + x: 0.44416338 + y: 0.17643969 + } + relative_keypoints { + x: 0.55514044 + y: 0.17731678 + } + relative_keypoints { + x: 0.5046702 + y: 0.2265771 + } + relative_keypoints { + x: 0.50227845 + y: 0.2719954 + } + relative_keypoints { + x: 0.37245658 + y: 0.20143759 + } + relative_keypoints { + x: 0.6084143 + y: 0.20409837 + } +} diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 2c0ea57ef..ff947ef54 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -1,131 +1,5 @@ -# This contains the MediaPipe Tasks NPM package definitions. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") -load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") -load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") -load( - "//mediapipe/framework/tool:mediapipe_files.bzl", - "mediapipe_files", -) - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_files(srcs = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", +exports_files([ + "karma.conf.ts", + "package.json", + "rollup.config.mjs", ]) - -# Audio - -mediapipe_ts_library( - name = "audio_lib", - srcs = ["audio.ts"], - deps = ["//mediapipe/tasks/web/audio:audio_lib"], -) - -rollup_bundle( - name = "audio_bundle", - config_file = "rollup.config.mjs", - entry_point = "audio.ts", - format = "cjs", - output_dir = False, - deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - ], -) - -pkg_npm( - name = "audio_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "audio", - "__DESCRIPTION__": "MediaPipe Audio Tasks", - }, - tgz = "audio.tgz", - deps = [ - "wasm/audio_wasm_internal.js", - "wasm/audio_wasm_internal.wasm", - ":audio_bundle", - ], -) - -# Text - -mediapipe_ts_library( - name = "text_lib", - srcs = ["text.ts"], - deps = ["//mediapipe/tasks/web/text:text_lib"], -) - -rollup_bundle( - name = "text_bundle", - config_file = "rollup.config.mjs", - entry_point = "text.ts", - format = "cjs", - output_dir = False, - deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - ], -) - -pkg_npm( - name = "text_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "text", - "__DESCRIPTION__": "MediaPipe Text Tasks", - }, - tgz = "text.tgz", - deps = [ - "wasm/text_wasm_internal.js", - "wasm/text_wasm_internal.wasm", - ":text_bundle", - ], -) - -# Vision - -mediapipe_ts_library( - name = "vision_lib", - srcs = ["vision.ts"], - deps = ["//mediapipe/tasks/web/vision:vision_lib"], -) - -rollup_bundle( - name = "vision_bundle", - config_file = "rollup.config.mjs", - entry_point = "vision.ts", - format = "cjs", - output_dir = False, - deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - ], -) - -pkg_npm( - name = "vision_pkg", - package_name = "@mediapipe/tasks-__NAME__", - srcs = ["package.json"], - substitutions = { - "__NAME__": "vision", - "__DESCRIPTION__": "MediaPipe Vision Tasks", - }, - tgz = "vision_pkg.tgz", - deps = [ - "wasm/vision_wasm_internal.js", - "wasm/vision_wasm_internal.wasm", - ":vision_bundle", - ], -) diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 
4f6e48b28..409836800 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1,13 +1,81 @@ # This contains the MediaPipe Audio Tasks. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) +AUDIO_LIBS = [ + "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", +] + mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], + deps = AUDIO_LIBS, +) + +mediapipe_ts_library( + name = "audio_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = AUDIO_LIBS, +) + +mediapipe_files(srcs = [ + "wasm/audio_wasm_internal.js", + "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", +]) + +rollup_bundle( + name = "audio_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", + ":audio_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "audio_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "audio", + "__DESCRIPTION__": "MediaPipe Audio Tasks", + "__TYPES__": "audio.d.ts", + }, + tgz = "audio.tgz", + deps = [ + "wasm/audio_wasm_internal.js", + "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", + ":audio_bundle", + ":package_json", ], ) diff --git a/mediapipe/tasks/web/audio/README.md b/mediapipe/tasks/web/audio/README.md new file mode 100644 index 000000000..834785709 --- /dev/null +++ b/mediapipe/tasks/web/audio/README.md @@ -0,0 +1,31 @@ +# MediaPipe Tasks Audio Package + +This package contains the audio tasks for MediaPipe. + +## Audio Classification + +The MediaPipe Audio Classification task performs classification on audio data. + +``` +const audio = await FilesetResolver.forAudioTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm" +); +const audioClassifier = await AudioClassifier.createFromModelPath(audio, + "https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_audio_classifier_with_metadata.tflite" +); +const classifications = audioClassifier.classify(audioData); +``` + +## Audio Embedding + +The MediaPipe Audio Embedding task extracts embeddings from audio data.
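+
+A hedged sketch of consuming the result: `embed()` (see the example below)
+returns a list of `AudioEmbedderResult` objects, each holding one `Embedding`
+per model head. Assuming the embedder is not configured for scalar
+quantization, the raw vector can be read as follows (variable names are
+illustrative):
+
+```
+// Illustrative only: read the float vector of the first head of the first
+// result returned by `audioEmbedder.embed()`.
+const vector = embeddings[0].embeddings[0].floatEmbedding;
+if (vector) {
+  console.log(`Embedding has ${vector.length} dimensions`);
+}
+```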
+ +``` +const audio = await FilesetResolver.forAudioTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio@latest/wasm" +); +const audioEmbedder = await AudioEmbedder.createFromModelPath(audio, + "model.tflite" +); +const embeddings = audioEmbedder.embed(audioData); +``` diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 1bc4af309..a94b4931d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,7 +2,8 @@ # # This task takes audio data and outputs the classification result. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,24 +11,58 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", - srcs = [ - "audio_classifier.ts", - "audio_classifier_options.ts", - "audio_classifier_result.ts", - ], + srcs = ["audio_classifier.ts"], + visibility = ["//visibility:public"], deps = [ + ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_declaration( + name = "audio_classifier_types", + srcs = [ + "audio_classifier_options.d.ts", + "audio_classifier_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +mediapipe_ts_library( + name = "audio_classifier_test_lib", + testonly = True, + srcs = [ + "audio_classifier_test.ts", + ], + deps = [ + ":audio_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_classifier_test", + deps = [":audio_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e3700cd7a..e26ead6a9 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,25 +18,25 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from 
'../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierResult} from './audio_classifier_result'; +export * from './audio_classifier_options'; +export * from './audio_classifier_result'; + const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; -// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and -// cannot be changed -// TODO: Change this to `audio_in` to match the name in the CC -// implementation -const AUDIO_STREAM = 'input_audio'; +const AUDIO_STREAM = 'audio_in'; const SAMPLE_RATE_STREAM = 'sample_rate'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; @@ -44,68 +44,71 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** * Initializes the Wasm runtime and creates a new audio classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioClassifierOptions The options for the audio classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
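   *
   * A hedged usage sketch (the CDN URL and model path below are illustrative
   * placeholders, not required values):
   *
   * ```
   * // `audioBuffer` is assumed to be a Web Audio API AudioBuffer.
   * const fileset = await FilesetResolver.forAudioTasks(
   *     'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-audio/wasm');
   * const classifier = await AudioClassifier.createFromOptions(fileset, {
   *   baseOptions: {modelAssetPath: 'audio_classifier.tflite'},
   *   maxResults: 3,
   * });
   * const results = classifier.classify(
   *     audioBuffer.getChannelData(0), audioBuffer.sampleRate);
   * ```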
*/ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - audioClassifierOptions: AudioClassifierOptions): + static createFromOptions( + wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - AudioClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(audioClassifierOptions); - return classifier; + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + audioClassifierOptions); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(new CachedGraphRunner(wasmModule, glCanvas)); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -117,34 +120,20 @@ export class AudioClassifier extends TaskRunner { * * @param options The options for the audio classifier. 
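   *
   * A small illustrative sketch (this mirrors the option-merging behavior
   * exercised by the unit tests later in this change):
   *
   * ```
   * await audioClassifier.setOptions({maxResults: 1});
   * await audioClassifier.setOptions({displayNamesLocale: 'en'});
   * // After both calls, maxResults is still 1 and displayNamesLocale is 'en'.
   * ```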
*/ - async setOptions(options: AudioClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override setOptions(options: AudioClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } - /** - * Sets the sample rate for all calls to `classify()` that omit an explicit - * sample rate. `48000` is used as a default if this method is not called. - * - * @param sampleRate A sample rate (e.g. `44100`). - */ - setDefaultSampleRate(sampleRate: number) { - this.defaultSampleRate = sampleRate; - } + // TODO: Add a classifyStream() that takes a timestamp /** - * Performs audio classification on the provided audio data and waits + * Performs audio classification on the provided audio clip and waits * synchronously for the response. * - * @param audioData An array of raw audio capture data, like - * from a call to getChannelData on an AudioBuffer. + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. * @param sampleRate The sample rate in Hz of the provided audio data. If not * set, defaults to the sample rate set via `setDefaultSampleRate()` or * `48000` if no custom default was set. @@ -152,18 +141,18 @@ export class AudioClassifier extends TaskRunner { */ classify(audioData: Float32Array, sampleRate?: number): AudioClassifierResult[] { - sampleRate = sampleRate ?? this.defaultSampleRate; + return this.processAudioClip(audioData, sampleRate); + } - // Configures the number of samples in the WASM layer. We re-configure the - // number of samples and the sample rate for every frame, but ignore other - // side effects of this function (such as sending the input side packet and - // the input stream header). - this.configureAudio( - /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); - - const timestamp = performance.now(); - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); - this.addAudioToStream(audioData, timestamp); + /** Sends an audio package to the graph and returns the classifications. */ + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioClassifierResult[] { + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -184,7 +173,7 @@ export class AudioClassifier extends TaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); @@ -206,9 +195,10 @@ export class AudioClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( - TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_CLASSIFICATIONS_STREAM, (binaryProtos, timestamp) => { this.addJsAudioClassificationResults(binaryProtos); + this.setLatestOutputTimestamp(timestamp); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts new file mode 100644 index 000000000..dc3c494bf --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Audio Classifier Task */ +export declare interface AudioClassifierOptions extends ClassifierOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts new file mode 100644 index 000000000..b7bb158de --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -0,0 +1,212 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioClassifier} from './audio_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class AudioClassifierFake extends AudioClassifier implements + MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + private protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + private resultProtoVector: ClassificationResult[] = []; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_classifications'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'addDoubleToStream') + .and.callFake((sampleRate, streamName, timestamp) => { + if (streamName === 'sample_rate') { + this.lastSampleRate = sampleRate; + } + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape') + .and.callFake( + (audioData, numChannels, numSamples, streamName, timestamp) => { + expect(numChannels).toBe(1); + }); + spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { + if (!this.protoVectorListener) return; + this.protoVectorListener( + this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary()), + 1337); + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } + + /** Sets the Protobuf that will be send to the API. 
*/ + setResults(results: ClassificationResult[]): void { + this.resultProtoVector = results; + } +} + +describe('AudioClassifier', () => { + let audioClassifier: AudioClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioClassifier = new AudioClassifierFake(); + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(audioClassifier); + verifyListenersRegistered(audioClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await audioClassifier.setOptions({maxResults: 1}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(audioClassifier); + + await audioClassifier.setOptions({maxResults: 5}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(audioClassifier); + }); + + it('merges options', async () => { + await audioClassifier.setOptions({maxResults: 1}); + await audioClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(audioClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([]), 44100); + expect(audioClassifier.lastSampleRate).toEqual(44100); + }); + + it('transforms results', async () => { + const resultProtoVector: ClassificationResult[] = []; + + let classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(0); + let classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + let classificationList = new ClassificationList(); + let classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + resultProtoVector.push(classificationResult); + + classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + classifcations = new Classifications(); + classificationList = new ClassificationList(); + classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classificationList.addClassification(classification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + resultProtoVector.push(classificationResult); + + // Invoke the audio classifier + audioClassifier.setResults(resultProtoVector); + const results = audioClassifier.classify(new Float32Array([])); + expect(results.length).toEqual(2); + expect(results[0]).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 
'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 0 + }); + expect(results[1]).toEqual({ + classifications: [{ + categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + timestampMs: 1 + }); + }); + + it('clears results between invocations', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + const classificationList = new ClassificationList(); + const classification = new Classification(); + classificationList.addClassification(classification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + audioClassifier.setResults([classificationResult]); + + // Invoke the gesture recognizer twice + const classifications1 = audioClassifier.classify(new Float32Array([])); + const classifications2 = audioClassifier.classify(new Float32Array([])); + + // Verify that gestures2 is not a concatenation of all previously returned + // gestures. + expect(classifications1).toEqual(classifications2); + }); +}); diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD new file mode 100644 index 000000000..68a7f7bd5 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -0,0 +1,67 @@ +# This contains the MediaPipe Audio Embedder Task. +# +# This task takes audio input and performs embedding. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + visibility = ["//visibility:public"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +mediapipe_ts_library( + name = "audio_embedder_test_lib", + testonly = True, + srcs = [ + "audio_embedder_test.ts", + ], + deps = [ + ":audio_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_embedder_test", + deps 
= [":audio_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..7411f95ef --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,224 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {AudioEmbedderOptions} from './audio_embedder_options'; +import {AudioEmbedderResult} from './audio_embedder_result'; + +export * from './audio_embedder_options'; +export * from './audio_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const AUDIO_STREAM = 'audio_in'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; +const AUDIO_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + +/** Performs embedding extraction on audio. */ +export class AudioEmbedder extends AudioTaskRunner { + private embeddingResults: AudioEmbedderResult[] = []; + private readonly options = new AudioEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new audio embedder from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param audioEmbedderOptions The options for the audio embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). 
+ */ + static createFromOptions( + wasmFileset: WasmFileset, + audioEmbedderOptions: AudioEmbedderOptions): Promise { + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + audioEmbedderOptions); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(new CachedGraphRunner(wasmModule, glCanvas)); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the audio embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the audio embedder. + */ + override setOptions(options: AudioEmbedderOptions): Promise { + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + return this.applyOptions(options); + } + + // TODO: Add a classifyStream() that takes a timestamp + + /** + * Performs embeding extraction on the provided audio clip and waits + * synchronously for the response. + * + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. + * @param sampleRate The sample rate in Hz of the provided audio data. If not + * set, defaults to the sample rate set via `setDefaultSampleRate()` or + * `48000` if no custom default was set. + * @return The embedding resuls of the audio + */ + embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] { + return this.processAudioClip(audioData, sampleRate); + } + + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types(float vs. quantized), have + * different sizes, or have an L2-norm of 0. 
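+ *
+ * Illustrative sketch, assuming `resultA` and `resultB` were previously
+ * returned by `embed()`:
+ *
+ * ```
+ * const similarity = AudioEmbedder.cosineSimilarity(
+ *     resultA[0].embeddings[0], resultB[0].embeddings[0]);
+ * ```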
+ */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioEmbedderResult[] { + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); + + this.embeddingResults = []; + this.finishProcessing(); + return this.embeddingResults; + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM); + embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.addOutputStream( + 'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.graphRunner.attachProtoListener( + EMBEDDINGS_STREAM, (binaryProto, timestamp) => { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + this.setLatestOutputTimestamp(timestamp); + }); + + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, (data, timestamp) => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts new file mode 100644 index 000000000..ac22728ab --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Audio Embedder Task */ +export declare interface AudioEmbedderOptions extends EmbedderOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts similarity index 83% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts rename to mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts index 51b2b3947..13abc28d9 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts @@ -14,4 +14,4 @@ * limitations under the License. */ -export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts new file mode 100644 index 000000000..a8a2b232b --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -0,0 +1,189 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder'; + + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; + protoVectorListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + this.attachListenerSpies[1] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_embeddings_out'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => { + this.lastSampleRate = sampleRate; + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape'); + } +} + +describe('AudioEmbedder', () => { + let audioEmbedder: AudioEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioEmbedder = new AudioEmbedderFake(); + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', () => { + verifyGraph(audioEmbedder); + verifyListenersRegistered(audioEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await audioEmbedder.setOptions({quantize: true}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(audioEmbedder); + + await audioEmbedder.setOptions({quantize: undefined}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(audioEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await audioEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + audioEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await audioEmbedder.setOptions({quantize: true}); + await audioEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + audioEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + 
audioEmbedder.embed(new Float32Array([]), 44100); + expect(audioEmbedder.lastSampleRate).toEqual(44100); + }); + + describe('transforms results', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResultProto(); + resultProto.addEmbeddings(embedding); + + function validateEmbeddingResult( + expectedEmbeddignResult: AudioEmbedderResult[]) { + expect(expectedEmbeddignResult.length).toEqual(1); + + const [embeddingResult] = expectedEmbeddignResult; + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + } + + it('from embeddings strem', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoListener!(resultProto.serializeBinary(), 1337); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([])); + validateEmbeddingResult(embeddingResults); + }); + + it('from timestamped embeddgins stream', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoVectorListener! + ([resultProto.serializeBinary()], 1337); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42); + validateEmbeddingResult(embeddingResults); + }); + }); +}); diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD new file mode 100644 index 000000000..cea689838 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -0,0 +1,11 @@ +# This package contains options shared by all MediaPipe Audio Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "audio_task_runner", + srcs = ["audio_task_runner.ts"], + deps = ["//mediapipe/tasks/web/core:task_runner"], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts new file mode 100644 index 000000000..ff39185f2 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -0,0 +1,47 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; + +/** Base class for all MediaPipe Audio Tasks. */ +export abstract class AudioTaskRunner extends TaskRunner { + private defaultSampleRate = 48000; + + /** + * Sets the sample rate for API calls that omit an explicit sample rate. 
+ * `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** Sends an audio packet to the graph and awaits results. */ + protected abstract process( + audioData: Float32Array, sampleRate: number, timestampMs: number): T; + + /** Sends a single audio clip to the graph and awaits results. */ + protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. + const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; + return this.process( + audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp); + } +} + + diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 114a8ceca..e7465878b 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,7 +14,14 @@ * limitations under the License. */ -// Audio Classifier -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_options'; -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_result'; -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioClassifier as AudioClassifierImpl} from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; + +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/types.ts b/mediapipe/tasks/web/audio/types.ts new file mode 100644 index 000000000..19073b708 --- /dev/null +++ b/mediapipe/tasks/web/audio/types.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 1b0e403ff..a0db59d0b 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -1,21 +1,31 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "category", srcs = ["category.d.ts"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "classification_result", srcs = ["classification_result.d.ts"], deps = [":category"], ) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "landmark", srcs = ["landmark.d.ts"], ) + +mediapipe_ts_declaration( + name = "embedding_result", + srcs = ["embedding_result.d.ts"], +) + +mediapipe_ts_declaration( + name = "rect", + srcs = ["rect.d.ts"], +) diff --git a/mediapipe/tasks/web/components/containers/embedding_result.d.ts b/mediapipe/tasks/web/components/containers/embedding_result.d.ts new file mode 100644 index 000000000..43d14d30e --- /dev/null +++ b/mediapipe/tasks/web/components/containers/embedding_result.d.ts @@ -0,0 +1,67 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * List of embeddings with an optional timestamp. + * + * One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will + * contain data, based on whether or not the embedder was configured to perform + * scalar quantization. + */ +export declare interface Embedding { + /** + * Floating-point embedding. Empty if the embedder was configured to perform + * scalar-quantization. + */ + floatEmbedding?: number[]; + + /** + * Scalar-quantized embedding. Empty if the embedder was not configured to + * perform scalar quantization. + */ + quantizedEmbedding?: Uint8Array; + + /** + * The index of the classifier head these categories refer to. This is + * useful for multi-head models. + */ + headIndex: number; + + /** + * The name of the classifier head, which is the corresponding tensor + * metadata name. + */ + headName: string; +} + +/** Embedding results for a given embedder model. */ +export interface EmbeddingResult { + /** + * The embedding results for each model head, i.e. one for each output tensor. + */ + embeddings: Embedding[]; + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of + * data corresponding to these results. + * + * This is only used for embedding extraction on time series (e.g. audio + * embedding). In these use cases, the amount of data to process might + * exceed the maximum size that the model can process: to solve this, the + * input data is split into multiple chunks starting at different timestamps. 
+ */ + timestampMs?: number; +} diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index f790d8a0b..0f916bf88 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -15,12 +15,29 @@ */ /** - * Landmark represents a point in 3D space with x, y, z coordinates. If - * normalized is true, the landmark coordinates is normalized respect to the - * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represenet a point in world coordinates. + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. + * x and y are normalized to [0.0, 1.0] by the image width and height + * respectively. z represents the landmark depth, and the smaller the value the + * closer the landmark is to the camera. The magnitude of z uses roughly the + * same scale as x. */ -export declare class Landmark { +export declare interface NormalizedLandmark { + /** The x coordinates of the normalized landmark. */ + x: number; + + /** The y coordinates of the normalized landmark. */ + y: number; + + /** The z coordinates of the normalized landmark. */ + z: number; +} + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. The + * landmark coordinates are in meters. z represents the landmark depth, + * and the smaller the value the closer the world landmark is to the camera. + */ +export declare interface Landmark { /** The x coordinates of the landmark. */ x: number; @@ -29,7 +46,4 @@ export declare class Landmark { /** The z coordinates of the landmark. */ z: number; - - /** Whether this landmark is normalized with respect to the image size. */ - normalized: boolean; } diff --git a/mediapipe/tasks/web/components/containers/rect.d.ts b/mediapipe/tasks/web/components/containers/rect.d.ts new file mode 100644 index 000000000..9afece9ca --- /dev/null +++ b/mediapipe/tasks/web/components/containers/rect.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + */ +export declare interface Rect { + left: number; + top: number; + right: number; + bottom: number; +} + +/** + * Defines a rectangle, used e.g. as part of detection results or as input + * region-of-interest. + * + * The coordinates are normalized with respect to the image dimensions, i.e. + * generally in [0,1] but they may exceed these bounds if describing a region + * overlapping the image. The origin is on the top-left corner of the image. 
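+ *
+ * For example, `{left: 0.25, top: 0.25, right: 0.75, bottom: 0.75}` would
+ * describe a centered region spanning half the image width and half the
+ * image height (illustrative values only).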
+ */ +export declare interface RectF { + left: number; + top: number; + right: number; + bottom: number; +} diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index e0d84b632..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,24 +14,92 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_options_test_lib", + testonly = True, + srcs = ["classifier_options.test.ts"], + deps = [ + ":classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +jasmine_node_test( + name = "classifier_options_test", + deps = [":classifier_options_test_lib"], +) + mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ - "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], ) mediapipe_ts_library( - name = "base_options", - srcs = ["base_options.ts"], + name = "classifier_result_test_lib", + testonly = True, + srcs = ["classifier_result.test.ts"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", + ":classifier_result", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", ], ) + +jasmine_node_test( + name = "classifier_result_test", + deps = [":classifier_result_test_lib"], +) + +mediapipe_ts_library( + name = "embedder_result", + srcs = ["embedder_result.ts"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +mediapipe_ts_library( + name = "embedder_result_test_lib", + testonly = True, + srcs = ["embedder_result.test.ts"], + deps = [ + ":embedder_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + ], +) + +jasmine_node_test( + name = "embedder_result_test", + deps = [":embedder_result_test_lib"], +) + +mediapipe_ts_library( + name = "embedder_options", + srcs = ["embedder_options.ts"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +mediapipe_ts_library( + name = "embedder_options_test_lib", + testonly = True, + srcs = ["embedder_options.test.ts"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +jasmine_node_test( + name = "embedder_options_test", + deps = [":embedder_options_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index ac24a8db6..000000000 --- 
a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. - */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ?? new Acceleration(); - if (options.delegate === 'gpu') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/components/processors/classifier_options.test.ts b/mediapipe/tasks/web/components/processors/classifier_options.test.ts new file mode 100644 index 000000000..928bda426 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_options.test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +import {convertClassifierOptionsToProto} from './classifier_options'; + +interface TestCase { + optionName: keyof ClassifierOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertClassifierOptionsToProto()', () => { + function verifyOption( + actualClassifierOptions: ClassifierOptionsProto, + expectedClassifierOptions: Record = {}): void { + expect(actualClassifierOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedClassifierOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + + classifierOptionsProto = + convertClassifierOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + classifierOptionsProto, + {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {maxResults: 2}, classifierOptionsProto); + verifyOption(classifierOptionsProto, {'maxResults': 2}); + }); + + it('merges options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {displayNamesLocale: 'en'}, classifierOptionsProto); + verifyOption( + classifierOptionsProto, {'maxResults': 1, 'displayNamesLocale': 'en'}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts 
b/mediapipe/tasks/web/components/processors/classifier_result.test.ts new file mode 100644 index 000000000..4b93d0a76 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts @@ -0,0 +1,80 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; + +import {convertFromClassificationResultProto} from './classifier_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromClassificationResultProto()', () => { + it('transforms custom values', () => { + const classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(2); + clasification.setScore(0.3); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 2, + score: 0.3, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 1 + }); + }); + + it('transforms default values', () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{index: 0, score: 0, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + }); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_options.test.ts b/mediapipe/tasks/web/components/processors/embedder_options.test.ts new file mode 100644 index 000000000..b879a6b29 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.test.ts @@ -0,0 +1,93 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +import {convertEmbedderOptionsToProto} from './embedder_options'; + +interface TestCase { + optionName: keyof EmbedderOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertEmbedderOptionsToProto()', () => { + function verifyOption( + actualEmbedderOptions: EmbedderOptionsProto, + expectedEmbedderOptions: Record = {}): void { + expect(actualEmbedderOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedEmbedderOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'l2Normalize', + protoName: 'l2Normalize', + customValue: true, + defaultValue: undefined + }, + { + optionName: 'quantize', + protoName: 'quantize', + customValue: true, + defaultValue: undefined + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + + embedderOptionsProto = + convertEmbedderOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let embedderOptionsProto = + convertEmbedderOptionsToProto({l2Normalize: true}); + verifyOption(embedderOptionsProto, {'l2Normalize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: false}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': false}); + }); + + it('replaces options', () => { + let embedderOptionsProto = convertEmbedderOptionsToProto({quantize: true}); + verifyOption(embedderOptionsProto, {'quantize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: true}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': true, 'quantize': true}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_options.ts b/mediapipe/tasks/web/components/processors/embedder_options.ts new file mode 100644 index 000000000..f000dbd64 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.ts @@ -0,0 +1,46 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +/** + * Converts a EmbedderOptions object to its Proto representation, optionally + * based on existing definition. + * @param options The options object to convert to a Proto. Only options that + * are expliclty provided are set. + * @param baseOptions A base object that options can be merged into. + */ +export function convertEmbedderOptionsToProto( + options: EmbedderOptions, + baseOptions?: EmbedderOptionsProto): EmbedderOptionsProto { + const embedderOptions = + baseOptions ? baseOptions.clone() : new EmbedderOptionsProto(); + + if (options.l2Normalize !== undefined) { + embedderOptions.setL2Normalize(options.l2Normalize); + } else if ('l2Normalize' in options) { // Check for undefined + embedderOptions.clearL2Normalize(); + } + + if (options.quantize !== undefined) { + embedderOptions.setQuantize(options.quantize); + } else if ('quantize' in options) { // Check for undefined + embedderOptions.clearQuantize(); + } + + return embedderOptions; +} diff --git a/mediapipe/tasks/web/components/processors/embedder_result.test.ts b/mediapipe/tasks/web/components/processors/embedder_result.test.ts new file mode 100644 index 000000000..97ba935c8 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.test.ts @@ -0,0 +1,75 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; + +import {convertFromEmbeddingResultProto} from './embedder_result'; + +// The OSS JS API does not support the builder pattern. 
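An illustrative sketch of how the `convertEmbedderOptionsToProto()` helper above is intended to be called (an editor's addition, not part of this diff; the option values are arbitrary). The merge and clear behavior mirrors the implementation shown above and the tests that follow.

import {convertEmbedderOptionsToProto} from './embedder_options';

// Only explicitly provided fields are set on the freshly created proto.
let proto = convertEmbedderOptionsToProto({l2Normalize: true, quantize: true});

// Passing the previous proto merges new settings into a clone of it.
proto = convertEmbedderOptionsToProto({quantize: false}, proto);

// Explicitly passing `undefined` clears the corresponding proto field.
proto = convertEmbedderOptionsToProto({l2Normalize: undefined}, proto);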
+// tslint:disable:jspb-use-builder-pattern + +describe('convertFromEmbeddingResultProto()', () => { + it('transforms custom values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + expect(timestampMs).toEqual(1); + }); + + it('transforms custom quantized values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + expect(timestampMs).toEqual(1); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_result.ts b/mediapipe/tasks/web/components/processors/embedder_result.ts new file mode 100644 index 000000000..285afe68a --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.ts @@ -0,0 +1,53 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Embedding as EmbeddingProto, EmbeddingResult as EmbeddingResultProto} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {Embedding, EmbeddingResult} from '../../../../tasks/web/components/containers/embedding_result'; + +const DEFAULT_INDEX = -1; + +/** + * Converts an Embedding proto to the Embedding object. + */ +function convertFromEmbeddingsProto(source: EmbeddingProto): Embedding { + const embedding: Embedding = { + headIndex: source.getHeadIndex() ?? DEFAULT_INDEX, + headName: source.getHeadName() ?? '', + }; + + if (source.hasFloatEmbedding()) { + embedding.floatEmbedding = source.getFloatEmbedding()!.getValuesList(); + } else { + const encodedValue = source.getQuantizedEmbedding()?.getValues() ?? ''; + embedding.quantizedEmbedding = typeof encodedValue == 'string' ? 
+ Uint8Array.from(atob(encodedValue), c => c.charCodeAt(0)) : encodedValue; + } + + return embedding; +} + +/** + * Converts an EmbedderResult proto to an EmbeddingResult object. + */ +export function convertFromEmbeddingResultProto( + embeddingResult: EmbeddingResultProto): EmbeddingResult { + const result: EmbeddingResult = { + embeddings: embeddingResult.getEmbeddingsList().map( + e => convertFromEmbeddingsProto(e)), + timestampMs: embeddingResult.getTimestampMs(), + }; + return result; +} diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD new file mode 100644 index 000000000..f4a215e48 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -0,0 +1,27 @@ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +mediapipe_ts_library( + name = "cosine_similarity_test_lib", + testonly = True, + srcs = ["cosine_similarity.test.ts"], + deps = [ + ":cosine_similarity", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +jasmine_node_test( + name = "cosine_similarity_test", + deps = [":cosine_similarity_test_lib"], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts new file mode 100644 index 000000000..2a82f388d --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts @@ -0,0 +1,85 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *
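A standalone restatement (editor's sketch, not in the diff) of the quantized-embedding handling in `convertFromEmbeddingsProto()` above: the OSS protobuf runtime may surface a `bytes` field either as a base64 string or as a `Uint8Array`, so both cases are normalized to `Uint8Array`. The helper name is hypothetical.

function decodeQuantizedEmbedding(encoded: string|Uint8Array): Uint8Array {
  // Matches the branch in the converter above: base64-decode strings,
  // pass byte arrays through unchanged.
  return typeof encoded === 'string' ?
      Uint8Array.from(atob(encoded), c => c.charCodeAt(0)) :
      encoded;
}

// e.g. decodeQuantizedEmbedding('AQID') yields Uint8Array [1, 2, 3].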
Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + *
http://www.apache.org/licenses/LICENSE-2.0 + * + *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +import {computeCosineSimilarity} from './cosine_similarity'; + +describe('computeCosineSimilarity', () => { + it('fails with quantized and float embeddings', () => { + const u: Embedding = {floatEmbedding: [1.0], headIndex: 0, headName: ''}; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([1.0]), + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between quantized and float embeddings/); + }); + + it('fails with zero norm', () => { + const u = {floatEmbedding: [0.0], headIndex: 0, headName: ''}; + expect(() => computeCosineSimilarity(u, u)) + .toThrowError( + /Cannot compute cosine similarity on embedding with 0 norm/); + }); + + it('fails with different sizes', () => { + const u: + Embedding = {floatEmbedding: [1.0, 2.0], headIndex: 0, headName: ''}; + const v: Embedding = { + floatEmbedding: [1.0, 2.0, 3.0], + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between embeddings of different sizes/); + }); + + it('succeeds with float embeddings', () => { + const u: Embedding = { + floatEmbedding: [1.0, 0.0, 0.0, 0.0], + headIndex: 0, + headName: '' + }; + const v: Embedding = { + floatEmbedding: [0.5, 0.5, 0.5, 0.5], + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(0.5); + }); + + it('succeeds with quantized embeddings', () => { + const u: Embedding = { + quantizedEmbedding: new Uint8Array([127, 0, 0, 0]), + headIndex: 0, + headName: '' + }; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([128, 0, 0, 0]), + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(-1.0); + }); +}); diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts new file mode 100644 index 000000000..b512478f4 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -0,0 +1,63 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *
Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + *
http://www.apache.org/licenses/LICENSE-2.0 + * + *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +/** + * Computes cosine similarity[1] between two `Embedding` objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), + * have different sizes, or have an L2-norm of 0. + */ +export function computeCosineSimilarity(u: Embedding, v: Embedding): number { + if (u.floatEmbedding && v.floatEmbedding) { + return compute(u.floatEmbedding, v.floatEmbedding); + } + if (u.quantizedEmbedding && v.quantizedEmbedding) { + return compute( + convertToBytes(u.quantizedEmbedding), + convertToBytes(v.quantizedEmbedding)); + } + throw new Error( + 'Cannot compute cosine similarity between quantized and float embeddings.'); +} + +function convertToBytes(data: Uint8Array): number[] { + return Array.from(data, v => v > 127 ? v - 256 : v); +} + +function compute(u: number[], v: number[]) { + if (u.length !== v.length) { + throw new Error( + `Cannot compute cosine similarity between embeddings of different sizes (${ + u.length} vs. ${v.length}).`); + } + let dotProduct = 0.0; + let normU = 0.0; + let normV = 0.0; + for (let i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new Error( + 'Cannot compute cosine similarity on embedding with 0 norm.'); + } + return dotProduct / Math.sqrt(normU * normV); +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 4fb57d6c3..371c75da0 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,33 +1,79 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
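A usage sketch for `computeCosineSimilarity()` (editor's addition, not part of the diff), built from `Embedding` literals like those in the test above. The quantized case also spells out why the test expects -1.0: the stored bytes are reinterpreted as signed values, so 127 stays 127 while 128 becomes -128, and the two vectors point in exactly opposite directions.

import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {computeCosineSimilarity} from './cosine_similarity';

const u: Embedding = {floatEmbedding: [1, 0, 0, 0], headIndex: 0, headName: ''};
const v: Embedding = {floatEmbedding: [0.5, 0.5, 0.5, 0.5], headIndex: 0, headName: ''};
// dot(u, v) = 0.5 and both norms are 1, so the similarity is 0.5.
console.log(computeCosineSimilarity(u, v));

// Quantized values are first converted to signed bytes: 127 -> 127, 128 -> -128,
// giving cos = (127 * -128) / (127 * 128) = -1.
const q1: Embedding = {quantizedEmbedding: new Uint8Array([127, 0, 0, 0]), headIndex: 0, headName: ''};
const q2: Embedding = {quantizedEmbedding: new Uint8Array([128, 0, 0, 0]), headIndex: 0, headName: ''};
console.log(computeCosineSimilarity(q1, q2));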
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( +mediapipe_ts_declaration( name = "core", srcs = [ - "base_options.d.ts", - "wasm_loader_options.d.ts", + "task_runner_options.d.ts", + "wasm_fileset.d.ts", ], ) mediapipe_ts_library( name = "task_runner", - srcs = [ - "task_runner.ts", - ], + srcs = ["task_runner.ts"], deps = [ + ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) mediapipe_ts_library( - name = "classifier_options", - srcs = [ - "classifier_options.d.ts", - ], + name = "fileset_resolver", + srcs = ["fileset_resolver.ts"], + visibility = ["//visibility:public"], + deps = [":core"], +) + +mediapipe_ts_library( + name = "task_runner_test_utils", + testonly = True, + srcs = [ + "task_runner_test_utils.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "task_runner_test_lib", + testonly = True, + srcs = [ + "task_runner_test.ts", + ], + deps = [ + ":core", + ":task_runner", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "task_runner_test", + deps = [":task_runner_test_lib"], +) + +mediapipe_ts_declaration( + name = "classifier_options", + srcs = ["classifier_options.d.ts"], + deps = [":core"], +) + +mediapipe_ts_declaration( + name = "embedder_options", + srcs = ["embedder_options.d.ts"], deps = [":core"], ) diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 3dec8d27e..08e7a7664 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -14,13 +14,8 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - -/** Options to configure the Mediapipe Classifier Task. */ +/** Options to configure a MediaPipe Classifier Task. */ export declare interface ClassifierOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - /** * The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts new file mode 100644 index 000000000..8669acfcb --- /dev/null +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -0,0 +1,34 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Options to configure a MediaPipe Embedder Task */ +export declare interface EmbedderOptions { + /** + * Whether to normalize the returned feature vector with L2 norm. Use this + * option only if the model does not already contain a native L2_NORMALIZATION + * TF Lite Op. In most cases, this is already the case and L2 norm is thus + * achieved through TF Lite inference. + */ + l2Normalize?: boolean|undefined; + + /** + * Whether the returned embedding should be quantized to bytes via scalar + * quantization. Embeddings are implicitly assumed to be unit-norm and + * therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + * the l2_normalize option if this is not the case. + */ + quantize?: boolean|undefined; +} diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts new file mode 100644 index 000000000..9917035a4 --- /dev/null +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -0,0 +1,122 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource URL builder + +import {WasmFileset} from './wasm_fileset'; + +let supportsSimd: boolean|undefined; + +/** + * Simple WASM program to test compatibility with the M91 instruction set. + * Compiled from + * https://github.com/GoogleChromeLabs/wasm-feature-detect/blob/main/src/detectors/simd/module.wat + */ +const WASM_SIMD_CHECK = new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3, + 2, 1, 0, 10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11 +]); + +async function isSimdSupported(): Promise { + if (supportsSimd === undefined) { + try { + await WebAssembly.instantiate(WASM_SIMD_CHECK); + supportsSimd = true; + } catch { + supportsSimd = false; + } + } + + return supportsSimd; +} + +async function createFileset( + taskName: string, basePath: string = ''): Promise { + const suffix = + await isSimdSupported() ? 'wasm_internal' : 'wasm_nosimd_internal'; + + return { + wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`, + wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`, + }; +} + +// tslint:disable:class-as-namespace + +/** + * Resolves the files required for the MediaPipe Task APIs. + * + * This class verifies whether SIMD is supported in the current environment and + * loads the SIMD files only if support is detected. The returned filesets + * require that the Wasm files are published without renaming. 
If this is not + * possible, you can invoke the MediaPipe Tasks APIs using a manually created + * `WasmFileset`. + */ +export class FilesetResolver { + /** + * Returns whether SIMD is supported in the current environment. + * + * If your environment requires custom locations for the MediaPipe Wasm files, + * you can use `isSimdSupported()` to decide whether to load the SIMD-based + * assets. + * + * @return Whether SIMD support was detected in the current environment. + */ + static isSimdSupported(): Promise { + return isSimdSupported(); + } + + /** + * Creates a fileset for the MediaPipe Audio tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Audio + * tasks. + */ + static forAudioTasks(basePath?: string): Promise { + return createFileset('audio', basePath); + } + + /** + * Creates a fileset for the MediaPipe Text tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Text + * tasks. + */ + static forTextTasks(basePath?: string): Promise { + return createFileset('text', basePath); + } + + /** + * Creates a fileset for the MediaPipe Vision tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Vision + * tasks. + */ + static forVisionTasks(basePath?: string): Promise { + return createFileset('vision', basePath); + } +} + + diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c948930fc..8d483d9ff 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,29 +14,127 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; -import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; -import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; + +import {WasmFileset} from './wasm_fileset'; + +// None of the MP Tasks ship bundle assets. +const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = - SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); +const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner); +/** + * An implementation of the GraphRunner that exposes the resource graph + * service. 
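A usage sketch for the `FilesetResolver` factories above (editor's addition; the hosting URL is a placeholder, and the task constructor that would consume the fileset lives outside this change).

import {FilesetResolver} from './fileset_resolver';

async function loadVisionFileset() {
  // 'https://example.com/mediapipe/wasm' is a hypothetical hosting directory.
  const fileset = await FilesetResolver.forVisionTasks('https://example.com/mediapipe/wasm');
  // Depending on detected SIMD support, these point at either the
  // vision_wasm_internal.* or the vision_wasm_nosimd_internal.* assets.
  console.log(fileset.wasmLoaderPath, fileset.wasmBinaryPath);
  return fileset;
}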
+ */ +export class CachedGraphRunner extends CachedGraphRunnerType {} + +/** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ +export async function createTaskRunner( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: TaskRunnerOptions): Promise { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + // Initialize a canvas if requested. If OffscreenCanvas is available, we + // let the graph runner initialize it by passing `undefined`. + const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; +} /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner extends WasmMediaPipeImageLib { +export abstract class TaskRunner { + protected abstract baseOptions: BaseOptionsProto; private processingErrors: Error[] = []; + private latestOutputTimestamp = 0; - constructor(wasmModule: WasmModule) { - super(wasmModule); + /** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ + protected static async createInstance( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: TaskRunnerOptions): Promise { + return createTaskRunner(type, initializeCanvas, fileset, options); + } + /** @hideconstructor protected */ + constructor(protected readonly graphRunner: CachedGraphRunner) { // Disables the automatic render-to-screen code, which allows for pure // CPU processing. - this.setAutoRenderToScreen(false); + this.graphRunner.setAutoRenderToScreen(false); // Enables use of our model resource caching graph service. - this.registerModelResourcesGraphService(); + this.graphRunner.registerModelResourcesGraphService(); } + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference. + */ + protected applyOptions(options: TaskRunnerOptions): Promise { + const baseOptions: BaseOptions = options.baseOptions || {}; + + // Validate that exactly one model is configured + if (options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath)) { + throw new Error( + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + } + + this.setAcceleration(baseOptions); + if (baseOptions.modelAssetPath) { + // We don't use `await` here since we want to apply most settings + // synchronously. 
+ return fetch(baseOptions.modelAssetPath.toString()) + .then(response => response.arrayBuffer()) + .then(buffer => { + this.setExternalFile(new Uint8Array(buffer)); + this.refreshGraph(); + }); + } else { + // Apply the setting synchronously. + this.setExternalFile(baseOptions.modelAssetBuffer); + this.refreshGraph(); + return Promise.resolve(); + } + } + + /** Appliest the current options to the MediaPipe graph. */ + protected abstract refreshGraph(): void; + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream. Will replace the previously running MediaPipe graph, @@ -47,11 +145,11 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { * @param isBinary This should be set to true if the graph is in * binary format, and false if it is in human-readable text format. */ - override setGraph(graphData: Uint8Array, isBinary: boolean): void { - this.attachErrorListener((code, message) => { + protected setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.graphRunner.attachErrorListener((code, message) => { this.processingErrors.push(new Error(message)); }); - super.setGraph(graphData, isBinary); + this.graphRunner.setGraph(graphData, isBinary); this.handleErrors(); } @@ -60,23 +158,62 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { * far as possible, performing all processing until no more processing can be * done. */ - override finishProcessing(): void { - super.finishProcessing(); + protected finishProcessing(): void { + this.graphRunner.finishProcessing(); this.handleErrors(); } + /* + * Sets the latest output timestamp received from the graph (in ms). + * Timestamps that are smaller than the currently latest output timestamp are + * ignored. + */ + protected setLatestOutputTimestamp(timestamp: number): void { + this.latestOutputTimestamp = + Math.max(this.latestOutputTimestamp, timestamp); + } + + /** Returns the latest output timestamp. */ + protected getLatestOutputTimestamp() { + return this.latestOutputTimestamp; + } + /** Throws the error from the error listener if an error was raised. */ private handleErrors() { - const errorCount = this.processingErrors.length; - if (errorCount === 1) { - // Re-throw error to get a more meaningful stacktrace - throw new Error(this.processingErrors[0].message); - } else if (errorCount > 1) { - throw new Error( - 'Encountered multiple errors: ' + - this.processingErrors.map(e => e.message).join(', ')); + try { + const errorCount = this.processingErrors.length; + if (errorCount === 1) { + // Re-throw error to get a more meaningful stacktrace + throw new Error(this.processingErrors[0].message); + } else if (errorCount > 1) { + throw new Error( + 'Encountered multiple errors: ' + + this.processingErrors.map(e => e.message).join(', ')); + } + } finally { + this.processingErrors = []; } - this.processingErrors = []; + } + + /** Configures the `externalFile` option */ + private setExternalFile(modelAssetBuffer?: Uint8Array): void { + const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); + if (modelAssetBuffer) { + externalFile.setFileContent(modelAssetBuffer); + } + this.baseOptions.setModelAsset(externalFile); + } + + /** Configures the `acceleration` option. */ + private setAcceleration(options: BaseOptions) { + const acceleration = + this.baseOptions.getAcceleration() ?? 
new Acceleration(); + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); + } + this.baseOptions.setAcceleration(acceleration); } } diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts similarity index 82% rename from mediapipe/tasks/web/core/base_options.d.ts rename to mediapipe/tasks/web/core/task_runner_options.d.ts index 86635b8c7..5f23cd4bf 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -16,7 +16,7 @@ // Placeholder for internal dependency on trusted resource url -/** Options to configure MediaPipe Tasks in general. */ +/** Options to configure MediaPipe model loading and processing. */ export declare interface BaseOptions { /** * The model path to the model asset file. Only one of `modelAssetPath` or @@ -31,5 +31,11 @@ export declare interface BaseOptions { modelAssetBuffer?: Uint8Array|undefined; /** Overrides the default backend to use for the provided model. */ - delegate?: 'cpu'|'gpu'|undefined; + delegate?: 'CPU'|'GPU'|undefined; +} + +/** Options to configure MediaPipe Tasks in general. */ +export declare interface TaskRunnerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts new file mode 100644 index 000000000..9a8aa32eb --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -0,0 +1,255 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
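An editor's sketch of the options shape consumed by `applyOptions()` above, matching the `TaskRunnerOptions` and `BaseOptions` declarations in task_runner_options.d.ts; the model URL is a placeholder, and only one of `modelAssetPath`/`modelAssetBuffer` may be set.

import {TaskRunnerOptions} from './task_runner_options';

const options: TaskRunnerOptions = {
  baseOptions: {
    // Placeholder URL; alternatively, provide `modelAssetBuffer` with the raw bytes.
    modelAssetPath: 'https://example.com/models/embedder.tflite',
    // 'GPU' selects the GPU delegate; anything else falls back to the TFLite/CPU path.
    delegate: 'GPU',
  },
};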
+ */ +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {TaskRunner} from '../../../tasks/web/core/task_runner'; +import {ErrorListener} from '../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource URL builder + +import {CachedGraphRunner} from './task_runner'; +import {TaskRunnerOptions} from './task_runner_options.d'; + +class TaskRunnerFake extends TaskRunner { + private errorListener: ErrorListener|undefined; + private errors: string[] = []; + + baseOptions = new BaseOptionsProto(); + + static createFake(): TaskRunnerFake { + return new TaskRunnerFake(); + } + + constructor() { + super(jasmine.createSpyObj([ + 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', + 'registerModelResourcesGraphService', 'attachErrorListener' + ])); + const graphRunner = this.graphRunner as jasmine.SpyObj; + expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled(); + expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); + graphRunner.attachErrorListener.and.callFake(listener => { + this.errorListener = listener; + }); + graphRunner.setGraph.and.callFake(() => { + this.throwErrors(); + }); + graphRunner.finishProcessing.and.callFake(() => { + this.throwErrors(); + }); + } + + enqueueError(message: string): void { + this.errors.push(message); + } + + override finishProcessing(): void { + super.finishProcessing(); + } + + override refreshGraph(): void {} + + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + super.setGraph(graphData, isBinary); + } + + setOptions(options: TaskRunnerOptions): Promise { + return this.applyOptions(options); + } + + private throwErrors(): void { + expect(this.errorListener).toBeDefined(); + for (const error of this.errors) { + this.errorListener!(/* errorCode= */ -1, error); + } + this.errors = []; + } +} + +describe('TaskRunner', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + let taskRunner: TaskRunnerFake; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + + taskRunner = TaskRunnerFake.createFake(); + }); + + it('handles errors during graph update', () => { + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError('Test error'); + }); + + it('handles errors during graph execution', () => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.finishProcessing(); + }).toThrowError('Test error'); + }); + + it('can handle multiple errors', () => { + taskRunner.enqueueError('Test error 1'); + taskRunner.enqueueError('Test error 2'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error 1, Test error 2/); + }); + + it('clears errors once thrown', () => { + taskRunner.enqueueError('Test error'); + + expect(() => { + 
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error/); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).not.toThrow(); + }); + + it('verifies that at least one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({}); + }) + .toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({ + baseOptions: { + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + } + }); + }) + .toThrowError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('doesn\'t require model once it is configured', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + expect(() => { + taskRunner.setOptions({}); + }).not.toThrowError(); + }); + + it('downloads model', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetPath: `foo`}}); + + expect(fetchSpy).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('changes model synchronously when bytes are provided', () => { + const resolvedPromise = taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + // Check that the change has been applied even though we do not await the + // above Promise + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + return resolvedPromise; + }); + + it('can enable CPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'CPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + // Clear backend + await taskRunner.setOptions({baseOptions: {delegate: undefined}}); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); +}); diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts new file mode 100644 index 000000000..62dd0463a --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -0,0 +1,115 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; +import {WasmModule} from '../../../web/graph_runner/graph_runner'; +import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service'; + +type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources; + +/** + * Convenience type for our fake WasmModule for Jasmine testing. + */ +export declare type SpyWasmModule = jasmine.SpyObj; + +/** + * Factory function for creating a fake WasmModule for our Jasmine tests, + * allowing our APIs to no longer rely on the Wasm layer so they can run tests + * in pure JS/TS (and optionally spy on the calls). + */ +export function createSpyWasmModule(): SpyWasmModule { + const spyWasmModule = jasmine.createSpyObj([ + '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', + '_attachProtoVectorListener', '_free', '_waitUntilIdle', + '_addStringToInputStream', '_registerModelResourcesGraphService', + '_configureAudio', '_malloc', '_addProtoToInputStream' + ]); + spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']); + return spyWasmModule; +} + +/** + * Sets up our equality testing to use a custom float equality checking function + * to avoid incorrect test results due to minor floating point inaccuracies. + */ +export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) { + jasmine.addCustomEqualityTester((a, b) => { // Custom float equality + if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { + return Math.abs(a - b) < tolerance; + } + return; + }); +} + +/** The minimum interface provided by a test fake. */ +export interface MediapipeTasksFake { + graph: CalculatorGraphConfig|undefined; + calculatorName: string; + attachListenerSpies: jasmine.Spy[]; +} + +/** An map of field paths to values */ +export type FieldPathToValue = [string[] | string, unknown]; + +/** + * Verifies that the graph has been initialized and that it contains the + * provided options. + */ +export function verifyGraph( + tasksFake: MediapipeTasksFake, + expectedCalculatorOptions?: FieldPathToValue, + expectedBaseOptions?: FieldPathToValue, + ): void { + expect(tasksFake.graph).toBeDefined(); + expect(tasksFake.graph!.getNodeList().length).toBe(1); + const node = tasksFake.graph!.getNodeList()[0].toObject(); + expect(node).toEqual( + jasmine.objectContaining({calculator: tasksFake.calculatorName})); + + if (expectedBaseOptions) { + const [fieldPath, value] = expectedBaseOptions; + let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } + + if (expectedCalculatorOptions) { + const [fieldPath, value] = expectedCalculatorOptions; + let proto = (node.options as {ext: unknown}).ext; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? 
{}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } +} + +/** + * Verifies all listeners (as exposed by `.attachListenerSpies`) have been + * attached at least once since the last call to `verifyListenersRegistered()`. + * This helps us to ensure that listeners are re-registered with every graph + * update. + */ +export function verifyListenersRegistered(tasksFake: MediapipeTasksFake): void { + for (const spy of tasksFake.attachListenerSpies) { + expect(spy.calls.count()).toBeGreaterThanOrEqual(1); + spy.calls.reset(); + } +} diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts similarity index 88% rename from mediapipe/tasks/web/core/wasm_loader_options.d.ts rename to mediapipe/tasks/web/core/wasm_fileset.d.ts index 74436583d..18227eab9 100644 --- a/mediapipe/tasks/web/core/wasm_loader_options.d.ts +++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts @@ -16,8 +16,8 @@ // Placeholder for internal dependency on trusted resource url -/** An object containing the locations of all Wasm assets */ -export declare interface WasmLoaderOptions { +/** An object containing the locations of the Wasm assets */ +export declare interface WasmFileset { /** The path to the Wasm loader script. */ wasmLoaderPath: string; /** The path to the Wasm binary. */ diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index d7d484ca4..89c9a599e 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -3,17 +3,9 @@ "version": "__VERSION__", "description": "__DESCRIPTION__", "main": "__NAME___bundle.js", - "module": "__NAME___bundle.js", - "exports": { - ".": "./__NAME___bundle.js", - "./loader": "./wasm/__NAME___wasm_internal.js", - "./wasm": "./wasm/__NAME___wasm_internal.wasm" - }, "author": "mediapipe@google.com", "license": "Apache-2.0", - "dependencies": { - "google-protobuf": "^3.21.2" - }, + "types": "__TYPES__", "homepage": "http://mediapipe.dev", "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] } diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs index 392b235fc..3b5119530 100644 --- a/mediapipe/tasks/web/rollup.config.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,9 +1,11 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; +import terser from '@rollup/plugin-terser'; export default { plugins: [ resolve(), - commonjs() + commonjs(), + terser() ] } diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index a369d0af0..ebe3403b2 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1,13 +1,81 @@ # This contains the MediaPipe Text Tasks. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) +mediapipe_files(srcs = [ + "wasm/text_wasm_internal.js", + "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", +]) + +TEXT_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", +] + mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], + deps = TEXT_LIBS, +) + +mediapipe_ts_library( + name = "text_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = TEXT_LIBS, +) + +rollup_bundle( + name = "text_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", deps = [ - "//mediapipe/tasks/web/text/text_classifier", + ":text_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "text_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "text", + "__DESCRIPTION__": "MediaPipe Text Tasks", + "__TYPES__": "text.d.ts", + }, + tgz = "text.tgz", + deps = [ + "wasm/text_wasm_internal.js", + "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", + ":package_json", + ":text_bundle", ], ) diff --git a/mediapipe/tasks/web/text/README.md b/mediapipe/tasks/web/text/README.md new file mode 100644 index 000000000..247dc6d30 --- /dev/null +++ b/mediapipe/tasks/web/text/README.md @@ -0,0 +1,34 @@ +# MediaPipe Tasks Text Package + +This package contains the text tasks for MediaPipe. + +## Text Classification + +MediaPipe Text Classifier task lets you classify text into a set of defined +categories, such as positive or negative sentiment. + +``` +const text = await FilesetResolver.forTextTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" +); +const textClassifier = await TextClassifier.createFromModelPath(text, + "https://storage.googleapis.com/mediapipe-tasks/text_classifier/bert_text_classifier.tflite" +); +const classifications = textClassifier.classifiy(textData); +``` + +For more information, refer to the [Text Classification](https://developers.google.com/mediapipe/solutions/text/text_classifier/web_js) documentation. + +## Text Embedding + +The MediaPipe Text Embedding task extracts embeddings from text data. 
+ +``` +const text = await FilesetResolver.forTextTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" +); +const textEmbedder = await TextEmbedder.createFromModelPath(text, + "https://storage.googleapis.com/mediapipe-tasks/text_embedder/mobilebert_embedding_with_metadata.tflite" +); +const embeddings = textEmbedder.embed(textData); +``` diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index dc511a426..cfa990e58 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,7 +14,14 @@ * limitations under the License. */ -// Text Classifier -export * from '../../../tasks/web/text/text_classifier/text_classifier_options'; -export * from '../../../tasks/web/text/text_classifier/text_classifier_result'; -export * from '../../../tasks/web/text/text_classifier/text_classifier'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const FilesetResolver = FilesetResolverImpl; +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 4ebdce18a..fd97c3db4 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -3,7 +3,8 @@ # This task takes text input performs Natural Language classification (including # BERT-based text classification). 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,24 +12,58 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", - srcs = [ - "text_classifier.ts", - "text_classifier_options.ts", - "text_classifier_result.ts", - ], + srcs = ["text_classifier.ts"], + visibility = ["//visibility:public"], deps = [ + ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_declaration( + name = "text_classifier_types", + srcs = [ + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +mediapipe_ts_library( + name = "text_classifier_test_lib", + testonly = True, + srcs = [ + "text_classifier_test.ts", + ], + deps = [ + ":text_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "text_classifier_test", + deps = [":text_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index e1d0c9601..ff314cfc3 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -17,18 +17,21 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from 
'../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierResult} from './text_classifier_result'; +export * from './text_classifier_options'; +export * from './text_classifier_result'; + const INPUT_STREAM = 'text_in'; const CLASSIFICATIONS_STREAM = 'classifications_out'; const TEXT_CLASSIFIER_GRAPH = @@ -45,59 +48,56 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textClassifierOptions The options for the text classifier. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - TextClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(textClassifierOptions); - return classifier; + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + textClassifierOptions); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(new CachedGraphRunner(wasmModule, glCanvas)); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -109,18 +109,19 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. */ - async setOptions(options: TextClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override setOptions(options: TextClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs Natural Language classification on the provided text and waits @@ -130,16 +131,17 @@ export class TextClassifier extends TaskRunner { * @return The classification result of the text */ classify(text: string): TextClassifierResult { - // Get classification result by running our MediaPipe graph. + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. + const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; this.classificationResult = {classifications: []}; - this.addStringToStream( - text, INPUT_STREAM, /* timestamp= */ performance.now()); + this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); this.finishProcessing(); return this.classificationResult; } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); @@ -156,10 +158,12 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts new file mode 100644 index 000000000..25592deb5 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Text Classifier Task */ +export declare interface TextClassifierOptions extends ClassifierOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts new file mode 100644 index 000000000..d9eb14865 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -0,0 +1,155 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextClassifier} from './text_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextClassifier', () => { + let textClassifier: TextClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textClassifier = new TextClassifierFake(); + await textClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(textClassifier); + verifyListenersRegistered(textClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await textClassifier.setOptions({maxResults: 1}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(textClassifier); + + await textClassifier.setOptions({maxResults: 5}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(textClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await textClassifier.setOptions({maxResults: 1}); + await textClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(textClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('transforms results', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + 
classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + // Pass the test data to our listener + textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textClassifier); + textClassifier.protoListener! + (classificationResult.serializeBinary(), 1337); + }); + + // Invoke the text classifier + const result = textClassifier.classify('foo'); + + expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }] + }); + }); +}); diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD new file mode 100644 index 000000000..1514944bf --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -0,0 +1,67 @@ +# This contains the MediaPipe Text Embedder Task. +# +# This task takes text input and performs embedding +# + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "text_embedder", + srcs = ["text_embedder.ts"], + visibility = ["//visibility:public"], + deps = [ + ":text_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "text_embedder_types", + srcs = [ + "text_embedder_options.d.ts", + "text_embedder_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +mediapipe_ts_library( + name = "text_embedder_test_lib", + testonly = True, + srcs = [ + "text_embedder_test.ts", + ], + deps = [ + ":text_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "text_embedder_test", + deps = [":text_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts 
b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts new file mode 100644 index 000000000..daa1d24ed --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -0,0 +1,192 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {TextEmbedderOptions} from './text_embedder_options'; +import {TextEmbedderResult} from './text_embedder_result'; + +export * from './text_embedder_options'; +export * from './text_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const INPUT_STREAM = 'text_in'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TEXT_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + +/** + * Performs embedding extraction on text. + */ +export class TextEmbedder extends TaskRunner { + private embeddingResult: TextEmbedderResult = {embeddings: []}; + private readonly options = new TextEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new text embedder from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param textEmbedderOptions The options for the text embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + textEmbedderOptions: TextEmbedderOptions): Promise { + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + textEmbedderOptions); + } + + /** + * Initializes the Wasm runtime and creates a new text embedder based on the + * provided model asset buffer. 
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise<TextEmbedder> { + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new text embedder based on the + * path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise<TextEmbedder> { + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(new CachedGraphRunner(wasmModule, glCanvas)); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + /** + * Sets new options for the text embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the text embedder. + */ + override setOptions(options: TextEmbedderOptions): Promise<void> { + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + return this.applyOptions(options); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Performs embedding extraction on the provided text and waits synchronously + * for the response. + * + * @param text The text to process. + * @return The embedding results of the text + */ + embed(text: string): TextEmbedderResult { + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. + const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; + this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); + this.finishProcessing(); + return this.embeddingResult; + } + + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + + /** Updates the MediaPipe graph configuration. 
*/ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + TextEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('TEXT:' + INPUT_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.graphRunner.attachProtoListener( + EMBEDDINGS_STREAM, (binaryProto, timestamp) => { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResult = + convertFromEmbeddingResultProto(embeddingResult); + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts new file mode 100644 index 000000000..7689ee0c1 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Text Embedder Task */ +export declare interface TextEmbedderOptions extends EmbedderOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts similarity index 83% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts rename to mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts index 93bd9927e..65640b507 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts @@ -14,4 +14,4 @@ * limitations under the License. */ -export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +export {Embedding, EmbeddingResult as TextEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts new file mode 100644 index 000000000..e26b85bf4 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -0,0 +1,167 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextEmbedder} from './text_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextEmbedder', () => { + let textEmbedder: TextEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textEmbedder = new TextEmbedderFake(); + await textEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(textEmbedder); + verifyListenersRegistered(textEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await textEmbedder.setOptions({quantize: true}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(textEmbedder); + + await textEmbedder.setOptions({quantize: undefined}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(textEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await textEmbedder.setOptions({quantize: true}); + await 
textEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + textEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('transforms results', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); + }); + + // Invoke the text embedder + const embeddingResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + }); + + it('transforms custom quantized values', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary(), 1337); + }); + + // Invoke the text embedder + const embeddingsResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingsResult.embeddings.length).toEqual(1); + expect(embeddingsResult.embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + }); +}); diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text/types.ts similarity index 68% rename from mediapipe/tasks/web/text.ts rename to mediapipe/tasks/web/text/types.ts index f8a0b6457..bd01b1c6f 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text/types.ts @@ -1,5 +1,5 @@ /** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,4 +14,6 @@ * limitations under the License. */ -export * from '../../tasks/web/text/index'; +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts deleted file mode 100644 index 6ff8f725b..000000000 --- a/mediapipe/tasks/web/vision.ts +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from '../../tasks/web/vision/index'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index abdbc54ea..a229cbd2a 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1,15 +1,85 @@ # This contains the MediaPipe Vision Tasks. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) package(default_visibility = ["//mediapipe/tasks:internal"]) +mediapipe_files(srcs = [ + "wasm/vision_wasm_internal.js", + "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", +]) + +VISION_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/image_segmenter", + "//mediapipe/tasks/web/vision/object_detector", +] + mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], + deps = VISION_LIBS, +) + +mediapipe_ts_library( + name = "vision_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = VISION_LIBS, +) + +rollup_bundle( + name = "vision_bundle", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "false", deps = [ - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/object_detector", + ":vision_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "package_json", + srcs = ["//mediapipe/tasks/web:package.json"], + outs = ["package.json"], + cmd = "cp $< $@", +) + +pkg_npm( + name = "vision_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = ["README.md"], + substitutions = { + "__NAME__": "vision", + "__DESCRIPTION__": "MediaPipe Vision Tasks", + "__TYPES__": "vision.d.ts", + }, + tgz = "vision_pkg.tgz", + deps = [ + "wasm/vision_wasm_internal.js", + "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", + ":package_json", + ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md new file mode 100644 index 000000000..9e86eafd3 --- /dev/null +++ b/mediapipe/tasks/web/vision/README.md @@ -0,0 +1,95 @@ +# MediaPipe Tasks Vision Package + +This package contains the vision tasks for MediaPipe. + +## Object Detection + +The MediaPipe Object Detector task lets you detect the presence and location of +multiple classes of objects within images or videos. 
+ +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const objectDetector = await ObjectDetector.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +const detections = objectDetector.detect(image); +``` + +For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation. + +## Image Classification + +The MediaPipe Image Classifier task lets you perform classification on images. +You can use this task to identify what an image represents among a set of +categories defined at training time. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const imageClassifier = await ImageClassifier.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/image_classifier/efficientnet_lite0_uint8.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +const classifications = imageClassifier.classify(image); +``` + +For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation. + +## Image Segmentation + +The MediaPipe Image Segmenter lets you segment an image into categories. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const imageSegmenter = await ImageSegmenter.createFromModelPath(vision, + "model.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +imageSegmenter.segment(image, (masks, width, height) => { + ... +}); +``` + +## Gesture Recognition + +The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real +time, and provides the recognized hand gesture results along with the landmarks +of the detected hands. You can use this task to recognize specific hand gestures +from a user, and invoke application features that correspond to those gestures. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task" +); +const image = document.getElementById("image") as HTMLImageElement; +const recognitions = gestureRecognizer.recognize(image); +``` + +## Handlandmark Detection + +The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in +an image. You can use this Task to localize key points of the hands and render +visual effects over the hands. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const handLandmarker = await HandLandmarker.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task" +); +const image = document.getElementById("image") as HTMLImageElement; +const landmarks = handLandmarker.detect(image); +``` + +For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation. 
+ diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD new file mode 100644 index 000000000..a0a008122 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -0,0 +1,57 @@ +# This package contains options shared by all MediaPipe Vision Tasks for Web. + +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_declaration( + name = "image_processing_options", + srcs = ["image_processing_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:rect", + ], +) + +mediapipe_ts_declaration( + name = "vision_task_options", + srcs = ["vision_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) + +mediapipe_ts_library( + name = "vision_task_runner", + srcs = ["vision_task_runner.ts"], + deps = [ + ":image_processing_options", + ":vision_task_options", + "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "vision_task_runner_test_lib", + testonly = True, + srcs = ["vision_task_runner.test.ts"], + deps = [ + ":image_processing_options", + ":vision_task_options", + ":vision_task_runner", + "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "vision_task_runner_test", + deps = [":vision_task_runner_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/core/image_processing_options.d.ts b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts new file mode 100644 index 000000000..b76731546 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts @@ -0,0 +1,42 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {RectF} from '../../../../tasks/web/components/containers/rect'; + +/** + * Options for image processing. + * + * If both region-of-interest and rotation are specified, the crop around the + * region-of-interest is extracted first, then the specified rotation is applied + * to the crop. + */ +export declare interface ImageProcessingOptions { + /** + * The optional region-of-interest to crop from the image. If not specified, + * the full image is used. + * + * Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'. + */ + regionOfInterest?: RectF; + + /** + * The rotation to apply to the image (or cropped region-of-interest), in + * degrees clockwise. 
+ * + * The rotation must be a multiple (positive or negative) of 90°. + */ + rotationDegrees?: number; +} diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts new file mode 100644 index 000000000..44b1660ff --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -0,0 +1,35 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** + * The two running modes of a vision task. + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ +export type RunningMode = 'IMAGE'|'VIDEO'; + +/** The options for configuring a MediaPipe vision task. */ +export declare interface VisionTaskOptions extends TaskRunnerOptions { + /** + * The running mode of the task. Default to the image mode. + * Vision tasks have two running modes: + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ + runningMode?: RunningMode; +} diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts new file mode 100644 index 000000000..4eb51afdb --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -0,0 +1,256 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {addJasmineCustomFloatEqualityTester} from '../../../../tasks/web/core/task_runner_test_utils'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskOptions} from './vision_task_options'; +import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner'; + + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; + +const IMAGE = {} as unknown as HTMLImageElement; +const TIMESTAMP = 42; + +class VisionTaskRunnerFake extends VisionTaskRunner { + baseOptions = new BaseOptionsProto(); + fakeGraphRunner: jasmine.SpyObj; + expectedImageSource?: ImageSource; + expectedNormalizedRect?: NormalizedRect; + + constructor(roiAllowed = true) { + super( + jasmine.createSpyObj([ + 'addProtoToStream', 'addGpuBufferAsImageToStream', + 'setAutoRenderToScreen', 'registerModelResourcesGraphService', + 'finishProcessing' + ]), + IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed); + + this.fakeGraphRunner = + this.graphRunner as unknown as jasmine.SpyObj; + + (this.graphRunner.addProtoToStream as jasmine.Spy) + .and.callFake((serializedData, type, streamName, timestamp) => { + expect(type).toBe('mediapipe.NormalizedRect'); + expect(streamName).toBe(NORM_RECT_STREAM); + expect(timestamp).toBe(TIMESTAMP); + + const actualNormalizedRect = + NormalizedRect.deserializeBinary(serializedData); + expect(actualNormalizedRect.toObject()) + .toEqual(this.expectedNormalizedRect!.toObject()); + }); + + (this.graphRunner.addGpuBufferAsImageToStream as jasmine.Spy) + .and.callFake((imageSource, streamName, timestamp) => { + expect(streamName).toBe(IMAGE_STREAM); + expect(timestamp).toBe(TIMESTAMP); + expect(imageSource).toBe(this.expectedImageSource!); + }); + + // SetOptions with a modelAssetBuffer runs synchonously + void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + } + + protected override refreshGraph(): void {} + + override setOptions(options: VisionTaskOptions): Promise { + return this.applyOptions(options); + } + + override processImageData( + image: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined): void { + super.processImageData(image, imageProcessingOptions); + } + + override processVideoData( + imageFrame: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined, + timestamp: number): void { + super.processVideoData(imageFrame, imageProcessingOptions, timestamp); + } + + expectNormalizedRect( + xCenter: number, yCenter: number, width: number, height: number): void { + const rect = new NormalizedRect(); + rect.setXCenter(xCenter); + rect.setYCenter(yCenter); + rect.setWidth(width); + rect.setHeight(height); + this.expectedNormalizedRect = rect; + } + + expectImage(imageSource: ImageSource): void { + this.expectedImageSource = imageSource; + } +} + +describe('VisionTaskRunner', () => { + beforeEach(() => { + addJasmineCustomFloatEqualityTester(); + }); + + it('can enable image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); + }); + + it('can enable video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: true})); + }); + + it('can clear running mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); + + // Clear running mode + await visionTaskRunner.setOptions( + {runningMode: /* imageProcessingOptions= */ undefined}); + expect(visionTaskRunner.baseOptions.toObject()) + 
.toEqual(jasmine.objectContaining({useStreamMode: false})); + }); + + it('cannot process images with video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); + expect(() => { + visionTaskRunner.processImageData( + IMAGE, /* imageProcessingOptions= */ undefined); + }).toThrowError(/Task is not initialized with image mode./); + }); + + it('cannot process video with image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + // Use default for `useStreamMode` + expect(() => { + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); + }).toThrowError(/Task is not initialized with video mode./); + + // Explicitly set to image mode + await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); + expect(() => { + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); + }).toThrowError(/Task is not initialized with video mode./); + }); + + it('sends packets to graph', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); + + visionTaskRunner.expectImage(IMAGE); + visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1); + visionTaskRunner.processVideoData( + IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP); + }); + + it('sends packets to graph with image processing options', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); + + visionTaskRunner.expectImage(IMAGE); + visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4); + visionTaskRunner.processVideoData( + IMAGE, + {regionOfInterest: {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8}}, + TIMESTAMP); + }); + + describe('validates processing options', () => { + it('with left > right', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.2, + right: 0.1, + top: 0.1, + bottom: 0.2, + } + }); + }).toThrowError('Expected RectF with left < right and top < bottom.'); + }); + + it('with top > bottom', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.2, + bottom: 0.1, + } + }); + }).toThrowError('Expected RectF with left < right and top < bottom.'); + }); + + it('with out of range values', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 1.1, + top: 0.1, + bottom: 0.2, + } + }); + }).toThrowError('Expected RectF values to be in [0,1].'); + }); + + + it('without region of interest support', () => { + const visionTaskRunner = + new VisionTaskRunnerFake(/* roiAllowed= */ false); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.1, + bottom: 0.2, + } + }); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('with non-90 degree rotation', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); + expect(() => { + visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42}); + }).toThrowError('Expected rotation to be a multiple of 90°.'); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts 
b/mediapipe/tasks/web/vision/core/vision_task_runner.ts new file mode 100644 index 000000000..b3e8ed4db --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -0,0 +1,157 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib'; +import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service'; + +import {VisionTaskOptions} from './vision_task_options'; + +// tslint:disable-next-line:enforce-name-casing +const GraphRunnerVisionType = + SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class VisionGraphRunner extends GraphRunnerVisionType {} + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Base class for all MediaPipe Vision Tasks. */ +export abstract class VisionTaskRunner extends TaskRunner { + /** + * Constructor to initialize a `VisionTaskRunner`. + * + * @param graphRunner the graph runner for this task. + * @param imageStreamName the name of the input image stream. + * @param normRectStreamName the name of the input normalized rect image + * stream used to provide (mandatory) rotation and (optional) + * region-of-interest. + * @param roiAllowed Whether this task supports Region-Of-Interest + * pre-processing + * + * @hideconstructor protected + */ + constructor( + protected override readonly graphRunner: VisionGraphRunner, + private readonly imageStreamName: string, + private readonly normRectStreamName: string, + private readonly roiAllowed: boolean) { + super(graphRunner); + } + + /** Configures the shared options of a vision task. */ + override applyOptions(options: VisionTaskOptions): Promise { + if ('runningMode' in options) { + const useStreamMode = + !!options.runningMode && options.runningMode !== 'IMAGE'; + this.baseOptions.setUseStreamMode(useStreamMode); + } + return super.applyOptions(options); + } + + /** Sends a single image to the graph and awaits results. */ + protected processImageData( + image: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined): void { + if (!!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with image mode. ' + + '\'runningMode\' must be set to \'IMAGE\'.'); + } + + // Increment the timestamp by 1 millisecond to guarantee that we send + // monotonically increasing timestamps to the graph. 
+ const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; + this.process(image, imageProcessingOptions, syntheticTimestamp); + } + + /** Sends a single video frame to the graph and awaits results. */ + protected processVideoData( + imageFrame: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined, + timestamp: number): void { + if (!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with video mode. ' + + '\'runningMode\' must be set to \'VIDEO\'.'); + } + this.process(imageFrame, imageProcessingOptions, timestamp); + } + + private convertToNormalizedRect(imageProcessingOptions?: + ImageProcessingOptions): NormalizedRect { + const normalizedRect = new NormalizedRect(); + + if (imageProcessingOptions?.regionOfInterest) { + if (!this.roiAllowed) { + throw new Error('This task doesn\'t support region-of-interest.'); + } + + const roi = imageProcessingOptions.regionOfInterest; + + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new Error('Expected RectF with left < right and top < bottom.'); + } + if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { + throw new Error('Expected RectF values to be in [0,1].'); + } + + normalizedRect.setXCenter((roi.left + roi.right) / 2.0); + normalizedRect.setYCenter((roi.top + roi.bottom) / 2.0); + normalizedRect.setWidth(roi.right - roi.left); + normalizedRect.setHeight(roi.bottom - roi.top); + return normalizedRect; + } else { + normalizedRect.setXCenter(0.5); + normalizedRect.setYCenter(0.5); + normalizedRect.setWidth(1); + normalizedRect.setHeight(1); + } + + if (imageProcessingOptions?.rotationDegrees) { + if (imageProcessingOptions?.rotationDegrees % 90 !== 0) { + throw new Error( + 'Expected rotation to be a multiple of 90°.', + ); + } + + // Convert to radians anti-clockwise. + normalizedRect.setRotation( + -Math.PI * imageProcessingOptions.rotationDegrees / 180.0); + } + + return normalizedRect; + } + + /** Runs the graph and blocks on the response. */ + private process( + imageSource: ImageSource, + imageProcessingOptions: ImageProcessingOptions|undefined, + timestamp: number): void { + const normalizedRect = this.convertToNormalizedRect(imageProcessingOptions); + this.graphRunner.addProtoToStream( + normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', + this.normRectStreamName, timestamp); + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, this.imageStreamName, timestamp ?? performance.now()); + this.finishProcessing(); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6b99f6ce4..9156e89b7 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -3,7 +3,8 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more gesture categories, using Gesture Recognizer. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,17 +12,15 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", - srcs = [ - "gesture_recognizer.ts", - "gesture_recognizer_options.ts", - "gesture_recognizer_result.ts", - ], + srcs = ["gesture_recognizer.ts"], + visibility = ["//visibility:public"], deps = [ + ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", - "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", @@ -30,11 +29,51 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_declaration( + name = "gesture_recognizer_types", + srcs = [ + "gesture_recognizer_options.d.ts", + "gesture_recognizer_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "gesture_recognizer_test_lib", + testonly = True, + srcs = [ + "gesture_recognizer_test.ts", + ], + deps = [ + ":gesture_recognizer", + ":gesture_recognizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + ], +) + +jasmine_node_test( + name = "gesture_recognizer_test", + tags = ["nomsan"], + deps = [":gesture_recognizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index c24d1a7b3..beea263ce 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -18,7 +18,7 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import 
{CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; -import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; @@ -26,17 +26,19 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; import {GestureRecognizerResult} from './gesture_recognizer_result'; +export * from './gesture_recognizer_options'; +export * from './gesture_recognizer_result'; export {ImageSource}; // The OSS JS API does not support the builder pattern. @@ -52,19 +54,13 @@ const GESTURE_RECOGNIZER_GRAPH = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; const DEFAULT_NUM_HANDS = 1; -const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CONFIDENCE = 0.5; const DEFAULT_CATEGORY_INDEX = -1; -const FULL_IMAGE_RECT = new NormalizedRect(); -FULL_IMAGE_RECT.setXCenter(0.5); -FULL_IMAGE_RECT.setYCenter(0.5); -FULL_IMAGE_RECT.setWidth(1); -FULL_IMAGE_RECT.setHeight(1); - /** Performs hand gesture recognition on images. 
*/ -export class GestureRecognizer extends TaskRunner { +export class GestureRecognizer extends VisionTaskRunner { private gestures: Category[][] = []; - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -79,66 +75,61 @@ export class GestureRecognizer extends TaskRunner { /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param gestureRecognizerOptions The options for the gesture recognizer. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const recognizer = await createMediaPipeLib( - GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await recognizer.setOptions(gestureRecognizerOptions); - return recognizer; + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + gestureRecognizerOptions); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return GestureRecognizer.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return GestureRecognizer.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new GestureRecognizerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.handLandmarksDetectorGraphOptions = @@ -152,12 +143,14 @@ export class GestureRecognizer extends TaskRunner { new HandGestureRecognizerGraphOptions(); this.options.setHandGestureRecognizerGraphOptions( this.handGestureRecognizerGraphOptions); + } - this.initDefaults(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -169,29 +162,15 @@ export class GestureRecognizer extends TaskRunner { * * @param options The options for the gesture recognizer. */ - async setOptions(options: GestureRecognizerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - - if ('numHands' in options) { - this.handDetectorGraphOptions.setNumHands( - options.numHands ?? DEFAULT_NUM_HANDS); - } - if ('minHandDetectionConfidence' in options) { - this.handDetectorGraphOptions.setMinDetectionConfidence( - options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); - } - if ('minHandPresenceConfidence' in options) { - this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( - options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); - } - if ('minTrackingConfidence' in options) { - this.handLandmarkerGraphOptions.setMinTrackingConfidence( - options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); - } + override setOptions(options: GestureRecognizerOptions): Promise { + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_CONFIDENCE); + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_CONFIDENCE); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? 
DEFAULT_CONFIDENCE); if (options.cannedGesturesClassifierOptions) { // Note that we have to support both JSPB and ProtobufJS and cannot @@ -225,59 +204,87 @@ export class GestureRecognizer extends TaskRunner { ?.clearClassifierOptions(); } - this.refreshGraph(); + return this.applyOptions(options); } /** * Performs gesture recognition on the provided single image and waits - * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `image`. + * + * @param image A single image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The detected gestures. */ - recognize(imageSource: ImageSource, timestamp: number = performance.now()): + recognize( + image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): GestureRecognizerResult { + this.resetResults(); + this.processImageData(image, imageProcessingOptions); + return this.processResults(); + } + + /** + * Performs gesture recognition on the provided video frame and waits + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The detected gestures. + */ + recognizeForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): + GestureRecognizerResult { + this.resetResults(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + return this.processResults(); + } + + private resetResults(): void { this.gestures = []; this.landmarks = []; this.worldLandmarks = []; this.handednesses = []; - - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( - FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', - NORM_RECT_STREAM, timestamp); - this.finishProcessing(); - - return { - gestures: this.gestures, - landmarks: this.landmarks, - worldLandmarks: this.worldLandmarks, - handednesses: this.handednesses - }; } - /** Sets the default values for the graph. */ - private initDefaults(): void { - this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); - this.handDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarkerGraphOptions.setMinTrackingConfidence( - DEFAULT_SCORE_THRESHOLD); + private processResults(): GestureRecognizerResult { + if (this.gestures.length === 0) { + // If no gestures are detected in the image, just return an empty list + return { + gestures: [], + landmarks: [], + worldLandmarks: [], + handednesses: [], + }; + } else { + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } } /** Converts the proto data to a Category[][] structure. 
*/ - private toJsCategories(data: Uint8Array[]): Category[][] { + private toJsCategories(data: Uint8Array[], populateIndex = true): + Category[][] { const result: Category[][] = []; for (const binaryProto of data) { const inputList = ClassificationList.deserializeBinary(binaryProto); const outputList: Category[] = []; for (const classification of inputList.getClassificationList()) { + const index = populateIndex && classification.hasIndex() ? + classification.getIndex()! : + DEFAULT_CATEGORY_INDEX; outputList.push({ score: classification.getScore() ?? 0, - index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + index, categoryName: classification.getLabel() ?? '', displayName: classification.getDisplayName() ?? '', }); @@ -292,13 +299,12 @@ export class GestureRecognizer extends TaskRunner { for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - normalized: true + z: handLandmarkProto.getZ() ?? 0 }); } this.landmarks.push(landmarks); @@ -319,8 +325,7 @@ export class GestureRecognizer extends TaskRunner { worldLandmarks.push({ x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false + z: handWorldLandmarkProto.getZ() ?? 0 }); } this.worldLandmarks.push(worldLandmarks); @@ -328,7 +333,7 @@ export class GestureRecognizer extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); @@ -353,18 +358,29 @@ export class GestureRecognizer extends TaskRunner { graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, (binaryProto, timestamp) => { + this.addJsLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { + this.adddJsWorldLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, (binaryProto, timestamp) => { + // Gesture index is not used, because the final gesture result comes + // from multiple classifiers. 
+ this.gestures.push( + ...this.toJsCategories(binaryProto, /* populateIndex= */ false)); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, (binaryProto, timestamp) => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts similarity index 90% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts index 45601a74c..dd8fc9548 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -14,14 +14,11 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Gesture Recognizer Task */ -export declare interface GestureRecognizerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface GestureRecognizerOptions extends VisionTaskOptions { /** * The maximum number of hands can be detected by the GestureRecognizer. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts new file mode 100644 index 000000000..323290008 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; + +export {Category, Landmark, NormalizedLandmark}; + +/** + * Represents the gesture recognition results generated by `GestureRecognizer`. + */ +export declare interface GestureRecognizerResult { + /** Hand landmarks of detected hands. */ + landmarks: NormalizedLandmark[][]; + + /** Hand landmarks in world coordniates of detected hands. */ + worldLandmarks: Landmark[][]; + + /** Handedness of detected hands. */ + handednesses: Category[][]; + + /** + * Recognized hand gestures of detected hands. Note that the index of the + * gesture is always -1, because the raw indices from multiple gesture + * classifiers cannot consolidate to a meaningful index. 
+ */ + gestures: Category[][]; +} diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts new file mode 100644 index 000000000..b2a2c0d72 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -0,0 +1,343 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; + +import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); + +function createHandednesses(): Uint8Array[] { + const handsProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('handedness_label'); + classification.setDisplayName('handedness_display_name'); + handsProto.addClassification(classification); + return [handsProto.serializeBinary()]; +} + +function createGestures(): Uint8Array[] { + const gesturesProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.2); + classification.setIndex(2); + classification.setLabel('gesture_label'); + classification.setDisplayName('gesture_display_name'); + gesturesProto.addClassification(classification); + return [gesturesProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const handLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +function createWorldLandmarks(): Uint8Array[] { + const handLandmarksProto = new LandmarkList(); + const landmark = new Landmark(); + landmark.setX(21); + landmark.setY(22); + landmark.setZ(23); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +class GestureRecognizerFake extends GestureRecognizer implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: 
CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): VisionGraphRunner { + return this.graphRunner; + } +} + +describe('GestureRecognizer', () => { + let gestureRecognizer: GestureRecognizerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + gestureRecognizer = new GestureRecognizerFake(); + await gestureRecognizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(gestureRecognizer); + verifyListenersRegistered(gestureRecognizer); + }); + + it('reloads graph when settings are changed', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + ]); + verifyListenersRegistered(gestureRecognizer); + + await gestureRecognizer.setOptions({numHands: 5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5 + ]); + verifyListenersRegistered(gestureRecognizer); + }); + + it('merges options', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + ]); + verifyGraph(gestureRecognizer, [ + [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + 0.5 + ]); + }); + + describe('setOptions()', () => { + interface TestCase { + optionPath: [keyof GestureRecognizerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numHands'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands' + ], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minHandDetectionConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minHandPresenceConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + { + optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'], + fieldPath: [ + 'handGestureRecognizerGraphOptions', + 'cannedGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.4, + defaultValue: undefined + }, + { + optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'], + fieldPath: [ + 
'handGestureRecognizerGraphOptions', + 'customGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.5, + defaultValue: undefined, + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): GestureRecognizerOptions { + const options: Record = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear ${testCase.optionPath[0]}`, async () => { + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.customValue]); + + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, undefined)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + gestureRecognizer.recognize( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); + gestureRecognizer.listeners.get('world_hand_landmarks')! + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestureRecognizer.getGraphRunner().addProtoToStream) + .toHaveBeenCalledTimes(1); + expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream) + .toHaveBeenCalledTimes(1); + expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + + expect(gestures).toEqual({ + 'gestures': [[{ + 'score': 0.2, + 'index': -1, + 'categoryName': 'gesture_label', + 'displayName': 'gesture_display_name' + }]], + 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]], + 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]], + 'handednesses': [[{ + 'score': 0.1, + 'index': 1, + 'categoryName': 'handedness_label', + 'displayName': 'handedness_display_name' + }]] + }); + }); + + it('clears results between invoations', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); + gestureRecognizer.listeners.get('world_hand_landmarks')! + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! 
+ (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337); + }); + + // Invoke the gesture recognizer twice + const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement); + const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement); + + // Verify that gestures2 is not a concatenation of all previously returned + // gestures. + expect(gestures2).toEqual(gestures1); + }); + + it('returns empty results when no gestures are detected', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')! + (createLandmarks(), 1337); + gestureRecognizer.listeners.get('world_hand_landmarks')! + (createWorldLandmarks(), 1337); + gestureRecognizer.listeners.get('handedness')! + (createHandednesses(), 1337); + gestureRecognizer.listeners.get('hand_gestures')!([], 1337); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestures).toEqual({ + 'gestures': [], + 'landmarks': [], + 'worldLandmarks': [], + 'handednesses': [] + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD new file mode 100644 index 000000000..c5687ee2f --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -0,0 +1,74 @@ +# This contains the MediaPipe Hand Landmarker Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more hand categories, using Hand Landmarker. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.ts"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "hand_landmarker_types", + srcs = [ + "hand_landmark.d.ts", + "hand_landmarker_options.d.ts", + "hand_landmarker_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "hand_landmarker_test_lib", + testonly = True, + srcs = [ + 
"hand_landmarker_test.ts", + ], + deps = [ + ":hand_landmarker", + ":hand_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + ], +) + +jasmine_node_test( + name = "hand_landmarker_test", + tags = ["nomsan"], + deps = [":hand_landmarker_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts new file mode 100644 index 000000000..ca2543f78 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** The 21 hand landmarks. */ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts new file mode 100644 index 000000000..cd0459372 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -0,0 +1,336 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationList} from '../../../../framework/formats/classification_pb'; +import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; +import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; +import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {HandLandmarkerOptions} from './hand_landmarker_options'; +import {HandLandmarkerResult} from './hand_landmarker_result'; + +export * from './hand_landmarker_options'; +export * from './hand_landmarker_result'; +export {ImageSource}; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const LANDMARKS_STREAM = 'hand_landmarks'; +const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks'; +const HANDEDNESS_STREAM = 'handedness'; +const HAND_LANDMARKER_GRAPH = + 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'; + +const DEFAULT_NUM_HANDS = 1; +const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CATEGORY_INDEX = -1; + +/** Performs hand landmarks detection on images. */ +export class HandLandmarker extends VisionTaskRunner { + private landmarks: NormalizedLandmark[][] = []; + private worldLandmarks: Landmark[][] = []; + private handednesses: Category[][] = []; + + private readonly options: HandLandmarkerGraphOptions; + private readonly handLandmarksDetectorGraphOptions: + HandLandmarksDetectorGraphOptions; + private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param handLandmarkerOptions The options for the HandLandmarker. + * Note that either a path to the model asset or a model buffer needs to + * be provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + handLandmarkerOptions: HandLandmarkerOptions): Promise { + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + handLandmarkerOptions); + } + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` based on + * the provided model asset buffer. 
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new `HandLandmarker` based on + * the path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); + + this.options = new HandLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); + this.handLandmarksDetectorGraphOptions = + new HandLandmarksDetectorGraphOptions(); + this.options.setHandLandmarksDetectorGraphOptions( + this.handLandmarksDetectorGraphOptions); + this.handDetectorGraphOptions = new HandDetectorGraphOptions(); + this.options.setHandDetectorGraphOptions(this.handDetectorGraphOptions); + + this.initDefaults(); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for this `HandLandmarker`. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the hand landmarker. + */ + override setOptions(options: HandLandmarkerOptions): Promise { + // Configure hand detector options. + if ('numHands' in options) { + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + } + if ('minHandDetectionConfidence' in options) { + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + // Configure hand landmark detector options. + if ('minTrackingConfidence' in options) { + this.options.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minHandPresenceConfidence' in options) { + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + return this.applyOptions(options); + } + + /** + * Performs hand landmarks detection on the provided single image and waits + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `image`. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The detected hand landmarks. 
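A short sketch of the option-merging behavior documented above, assuming a `handLandmarker` instance created as in the earlier sketch:

    await handLandmarker.setOptions({numHands: 5});                      // only numHands changes
    await handLandmarker.setOptions({minHandDetectionConfidence: 0.8});  // numHands stays 5
    await handLandmarker.setOptions({numHands: undefined});              // restores the default (1)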
+ */ + detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + HandLandmarkerResult { + this.resetResults(); + this.processImageData(image, imageProcessingOptions); + return this.processResults(); + } + + /** + * Performs hand landmarks detection on the provided video frame and waits + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The detected hand landmarks. + */ + detectForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): HandLandmarkerResult { + this.resetResults(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + return this.processResults(); + } + + private resetResults(): void { + this.landmarks = []; + this.worldLandmarks = []; + this.handednesses = []; + } + + private processResults(): HandLandmarkerResult { + return { + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } + + /** Sets the default values for the graph. */ + private initDefaults(): void { + this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.options.setMinTrackingConfidence(DEFAULT_SCORE_THRESHOLD); + } + + /** Converts the proto data to a Category[][] structure. */ + private toJsCategories(data: Uint8Array[]): Category[][] { + const result: Category[][] = []; + for (const binaryProto of data) { + const inputList = ClassificationList.deserializeBinary(binaryProto); + const outputList: Category[] = []; + for (const classification of inputList.getClassificationList()) { + outputList.push({ + score: classification.getScore() ?? 0, + index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }); + } + result.push(outputList); + } + return result; + } + + /** Converts raw data into a landmark, and adds it to our landmarks list. */ + private addJsLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handLandmarksProto = + NormalizedLandmarkList.deserializeBinary(binaryProto); + const landmarks: NormalizedLandmark[] = []; + for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { + landmarks.push({ + x: handLandmarkProto.getX() ?? 0, + y: handLandmarkProto.getY() ?? 0, + z: handLandmarkProto.getZ() ?? 0, + }); + } + this.landmarks.push(landmarks); + } + } + + /** + * Converts raw data into a world landmark, and adds it to our worldLandmarks + * list. + */ + private adddJsWorldLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handWorldLandmarksProto = + LandmarkList.deserializeBinary(binaryProto); + const worldLandmarks: Landmark[] = []; + for (const handWorldLandmarkProto of + handWorldLandmarksProto.getLandmarkList()) { + worldLandmarks.push({ + x: handWorldLandmarkProto.getX() ?? 0, + y: handWorldLandmarkProto.getY() ?? 0, + z: handWorldLandmarkProto.getZ() ?? 
0, + }); + } + this.worldLandmarks.push(worldLandmarks); + } + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(LANDMARKS_STREAM); + graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM); + graphConfig.addOutputStream(HANDEDNESS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + HandLandmarkerGraphOptions.ext, this.options); + + const landmarkerNode = new CalculatorGraphConfig.Node(); + landmarkerNode.setCalculator(HAND_LANDMARKER_GRAPH); + landmarkerNode.addInputStream('IMAGE:' + IMAGE_STREAM); + landmarkerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + landmarkerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM); + landmarkerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM); + landmarkerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM); + landmarkerNode.setOptions(calculatorOptions); + + graphConfig.addNode(landmarkerNode); + + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, (binaryProto, timestamp) => { + this.addJsLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { + this.adddJsWorldLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, (binaryProto, timestamp) => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts new file mode 100644 index 000000000..fe79b7089 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts @@ -0,0 +1,44 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe HandLandmarker Task */ +export declare interface HandLandmarkerOptions extends VisionTaskOptions { + /** + * The maximum number of hands can be detected by the HandLandmarker. + * Defaults to 1. + */ + numHands?: number|undefined; + + /** + * The minimum confidence score for the hand detection to be considered + * successful. Defaults to 0.5. + */ + minHandDetectionConfidence?: number|undefined; + + /** + * The minimum confidence score of hand presence score in the hand landmark + * detection. Defaults to 0.5. 
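With the graph wiring above in place, a hedged sketch of the two synchronous entry points defined earlier in this file; `imageElement` and `videoElement` are assumed DOM elements, and the `'VIDEO'` running-mode literal comes from VisionTaskOptions, which is outside this diff, so treat it as an assumption:

    // Image mode (default): one-shot detection on a still image element.
    const stillResult = handLandmarker.detect(imageElement);

    // Video mode: switch modes once, then pass a monotonically increasing timestamp in ms per frame.
    await handLandmarker.setOptions({runningMode: 'VIDEO'});
    const frameResult = handLandmarker.detectForVideo(videoElement, performance.now());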
+ */ + minHandPresenceConfidence?: number|undefined; + + /** + * The minimum confidence score for the hand tracking to be considered + * successful. Defaults to 0.5. + */ + minTrackingConfidence?: number|undefined; +} diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts similarity index 74% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts rename to mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 7c295c9e9..8a6d9bfa6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -15,21 +15,20 @@ */ import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; + +export {Landmark, NormalizedLandmark, Category}; /** - * Represents the gesture recognition results generated by `GestureRecognizer`. + * Represents the hand landmarks detection results generated by `HandLandmarker`. */ -export declare interface GestureRecognizerResult { +export declare interface HandLandmarkerResult { /** Hand landmarks of detected hands. */ - landmarks: Landmark[][]; + landmarks: NormalizedLandmark[][]; /** Hand landmarks in world coordinates of detected hands. */ worldLandmarks: Landmark[][]; /** Handedness of detected hands. */ handednesses: Category[][]; - - /** Recognized hand gestures of detected hands */ - gestures: Category[][]; } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts new file mode 100644 index 000000000..5fd493424 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -0,0 +1,261 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {HandLandmarker} from './hand_landmarker'; +import {HandLandmarkerOptions} from './hand_landmarker_options'; + + +// The OSS JS API does not support the builder pattern.
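To show how the three parallel arrays in HandLandmarkerResult line up, a small illustrative helper (the import path is the assumed published bundle name; a relative import of the declaration above would work equally well):

    import {HandLandmarkerResult} from '@mediapipe/tasks-vision';  // assumed export

    function describeHands(result: HandLandmarkerResult): string[] {
      return result.landmarks.map((hand, i) => {
        const top = result.handednesses[i]?.[0];  // highest-scoring handedness category for hand i
        const wrist = hand[0];                    // HandLandmark.WRIST
        return `${top?.categoryName ?? 'unknown'} hand, wrist at ` +
            `(${wrist.x.toFixed(2)}, ${wrist.y.toFixed(2)})`;
      });
    }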
+// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); + +function createHandednesses(): Uint8Array[] { + const handsProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('handedness_label'); + classification.setDisplayName('handedness_display_name'); + handsProto.addClassification(classification); + return [handsProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const handLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +function createWorldLandmarks(): Uint8Array[] { + const handLandmarksProto = new LandmarkList(); + const landmark = new Landmark(); + landmark.setX(21); + landmark.setY(22); + landmark.setZ(23); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(hand_landmarks|world_hand_landmarks|handedness|hand_hands)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): VisionGraphRunner { + return this.graphRunner; + } +} + +describe('HandLandmarker', () => { + let handLandmarker: HandLandmarkerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + handLandmarker = new HandLandmarkerFake(); + await handLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(handLandmarker); + verifyListenersRegistered(handLandmarker); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(handLandmarker); + + await handLandmarker.setOptions({numHands: 1}); + verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]); + verifyListenersRegistered(handLandmarker); + + await handLandmarker.setOptions({numHands: 5}); + verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]); + verifyListenersRegistered(handLandmarker); + }); + + it('merges options', async () => { + await handLandmarker.setOptions({numHands: 1}); + await handLandmarker.setOptions({minHandDetectionConfidence: 0.5}); + verifyGraph(handLandmarker, [ + 'handDetectorGraphOptions', + {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5} + ]); + }); + + describe('setOptions()', () => { + interface TestCase { + optionPath: [keyof HandLandmarkerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + 
defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numHands'], + fieldPath: ['handDetectorGraphOptions', 'numHands'], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minHandDetectionConfidence'], + fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minHandPresenceConfidence'], + fieldPath: + ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'], + customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): HandLandmarkerOptions { + const options: Record = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + handLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await handLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear ${testCase.optionPath[0]}`, async () => { + await handLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]); + + await handLandmarker.setOptions( + createOptions(testCase.optionPath, undefined)); + verifyGraph( + handLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + handLandmarker.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + // Pass the test data to our listener + handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(handLandmarker); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); + handLandmarker.listeners.get('world_hand_landmarks')! 
+ (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); + }); + + // Invoke the hand landmarker + const landmarks = handLandmarker.detect({} as HTMLImageElement); + expect(handLandmarker.getGraphRunner().addProtoToStream) + .toHaveBeenCalledTimes(1); + expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream) + .toHaveBeenCalledTimes(1); + expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + + expect(landmarks).toEqual({ + 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]], + 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]], + 'handednesses': [[{ + 'score': 0.1, + 'index': 1, + 'categoryName': 'handedness_label', + 'displayName': 'handedness_display_name' + }]] + }); + }); + + it('clears results between invocations', async () => { + // Pass the test data to our listener + handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); + handLandmarker.listeners.get('world_hand_landmarks')! + (createWorldLandmarks(), 1337); + handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); + }); + + // Invoke the hand landmarker twice + const landmarks1 = handLandmarker.detect({} as HTMLImageElement); + const landmarks2 = handLandmarker.detect({} as HTMLImageElement); + + // Verify that landmarks2 is not a concatenation of all previously returned + // landmarks. + expect(landmarks1).toEqual(landmarks2); + }); +}); diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e96d6a8e3..86c7d8457 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -2,7 +2,8 @@ # # This task takes video or image frames and outputs the classification result.
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,24 +11,61 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", - srcs = [ - "image_classifier.ts", - "image_classifier_options.ts", - "image_classifier_result.ts", - ], + srcs = ["image_classifier.ts"], + visibility = ["//visibility:public"], deps = [ + ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_declaration( + name = "image_classifier_types", + srcs = [ + "image_classifier_options.d.ts", + "image_classifier_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "image_classifier_test_lib", + testonly = True, + srcs = [ + "image_classifier_test.ts", + ], + deps = [ + ":image_classifier", + ":image_classifier_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "image_classifier_test", + tags = ["nomsan"], + deps = [":image_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index ba4b6c907..071513b19 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -17,13 +17,14 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from 
'../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -31,76 +32,85 @@ import {ImageClassifierResult} from './image_classifier_result'; const IMAGE_CLASSIFIER_GRAPH = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; -const INPUT_STREAM = 'input_image'; +const IMAGE_STREAM = 'input_image'; +const NORM_RECT_STREAM = 'norm_rect'; const CLASSIFICATIONS_STREAM = 'classifications'; +export * from './image_classifier_options'; +export * from './image_classifier_result'; export {ImageSource}; // Used in the public API // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern /** Performs classification on images. */ -export class ImageClassifier extends TaskRunner { +export class ImageClassifier extends VisionTaskRunner { private classificationResult: ImageClassifierResult = {classifications: []}; private readonly options = new ImageClassifierGraphOptions(); /** * Initializes the Wasm runtime and creates a new image classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location + * Wasm binary and its loader. * @param imageClassifierOptions The options for the image classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - imageClassifierOptions: ImageClassifierOptions): + static createFromOptions( + wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - ImageClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(imageClassifierOptions); - return classifier; + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + imageClassifierOptions); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. 
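After this refactoring the classifier is created the same way as the other vision tasks. A hedged sketch, where `vision` stands for a WasmFileset obtained elsewhere and the model asset name is hypothetical:

    import {ImageClassifier} from '@mediapipe/tasks-vision';  // assumed bundle export

    const classifier = await ImageClassifier.createFromOptions(vision, {
      baseOptions: {modelAssetPath: 'classifier.tflite'},  // hypothetical asset
      maxResults: 3,                                       // ClassifierOptions field
    });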
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ true); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -112,41 +122,53 @@ export class ImageClassifier extends TaskRunner { * * @param options The options for the image classifier. */ - async setOptions(options: ImageClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override setOptions(options: ImageClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** - * Performs image classification on the provided image and waits synchronously - * for the response. + * Performs image classification on the provided single image and waits + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `image`. * - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The classification result of the image */ - classify(imageSource: ImageSource, timestamp?: number): + classify(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): ImageClassifierResult { - // Get classification result by running our MediaPipe graph. 
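A hedged sketch of reading the classification output described above; `imageElement` is an assumed HTMLImageElement already on the page, and `classifier` is an instance created as in the earlier sketch:

    const result = classifier.classify(imageElement);
    const top = result.classifications[0]?.categories[0];  // top category of the first head
    console.log(`${top?.categoryName}: ${(top?.score ?? 0).toFixed(2)}`);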
this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( - imageSource, INPUT_STREAM, timestamp ?? performance.now()); - this.finishProcessing(); + this.processImageData(image, imageProcessingOptions); + return this.classificationResult; + } + + /** + * Performs image classification on the provided video frame and waits + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The classification result of the image + */ + classifyForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): ImageClassifierResult { + this.classificationResult = {classifications: []}; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); return this.classificationResult; } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -157,16 +179,19 @@ export class ImageClassifier extends TaskRunner { // are built-in. const classifierNode = new CalculatorGraphConfig.Node(); classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); - classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); + classifierNode.addInputStream('IMAGE:' + IMAGE_STREAM); + classifierNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM); classifierNode.setOptions(calculatorOptions); graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts new file mode 100644 index 000000000..e99dd2b69 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Image Classifier Task. */ +export declare interface ImageClassifierOptions extends ClassifierOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts deleted file mode 100644 index a5f5c2386..000000000 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts rename to mediapipe/tasks/web/vision/image_classifier/image_classifier_result.d.ts diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts new file mode 100644 index 000000000..60595310e --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -0,0 +1,153 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageClassifier} from './image_classifier'; + +// The OSS JS API does not support the builder pattern. 
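Because ImageClassifierOptions simply combines ClassifierOptions with VisionTaskOptions, options can be updated incrementally, which the tests below also exercise. A hedged sketch (the `'VIDEO'` running-mode literal is assumed from VisionTaskOptions, which is outside this diff):

    await classifier.setOptions({maxResults: 1});
    await classifier.setOptions({displayNamesLocale: 'en'});  // maxResults stays 1

    // Per-frame classification requires the video running mode and a timestamp in ms.
    await classifier.setOptions({runningMode: 'VIDEO'});
    const frameResult = classifier.classifyForVideo(videoElement, performance.now());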
+// tslint:disable:jspb-use-builder-pattern + +class ImageClassifierFake extends ImageClassifier implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProto: Uint8Array, timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageClassifier', () => { + let imageClassifier: ImageClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageClassifier = new ImageClassifierFake(); + await imageClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(imageClassifier); + verifyListenersRegistered(imageClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await imageClassifier.setOptions({maxResults: 1}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(imageClassifier); + + await imageClassifier.setOptions({maxResults: 5}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(imageClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageClassifier.setOptions({maxResults: 1}); + await imageClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyGraph( + imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']); + }); + + it('transforms results', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(1); + clasification.setScore(0.2); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + // Pass the test data to our listener + imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageClassifier); + imageClassifier.protoListener! 
+ (classificationResult.serializeBinary(), 1337); + }); + + // Invoke the image classifier + const result = imageClassifier.classify({} as HTMLImageElement); + + expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }] + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD new file mode 100644 index 000000000..449cee9bb --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -0,0 +1,69 @@ +# This contains the MediaPipe Image Embedder Task. +# +# This task performs embedding extraction on images. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_embedder", + srcs = ["image_embedder.ts"], + visibility = ["//visibility:public"], + deps = [ + ":image_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "image_embedder_types", + srcs = [ + "image_embedder_options.d.ts", + "image_embedder_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "image_embedder_test_lib", + testonly = True, + srcs = [ + "image_embedder_test.ts", + ], + deps = [ + ":image_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "image_embedder_test", + deps = [":image_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts new file mode 100644 index 000000000..fdeb92f3f --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -0,0 +1,220 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {ImageEmbedderOptions} from './image_embedder_options'; +import {ImageEmbedderResult} from './image_embedder_result'; + + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TEXT_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + +export * from './image_embedder_options'; +export * from './image_embedder_result'; +export {ImageSource}; // Used in the public API + +/** Performs embedding extraction on images. */ +export class ImageEmbedder extends VisionTaskRunner { + private readonly options = new ImageEmbedderGraphOptions(); + private embeddings: ImageEmbedderResult = {embeddings: []}; + + /** + * Initializes the Wasm runtime and creates a new image embedder from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param imageEmbedderOptions The options for the image embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + imageEmbedderOptions: ImageEmbedderOptions): Promise { + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + imageEmbedderOptions); + } + + /** + * Initializes the Wasm runtime and creates a new image embedder based on the + * provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. 
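A hedged creation sketch for the embedder, mirroring the other tasks; the model path is hypothetical and `vision` is an assumed WasmFileset:

    import {ImageEmbedder} from '@mediapipe/tasks-vision';  // assumed bundle export

    const embedder = await ImageEmbedder.createFromOptions(vision, {
      baseOptions: {modelAssetPath: 'image_embedder.tflite'},  // hypothetical asset
      l2Normalize: true,  // EmbedderOptions field, exercised in the tests below
    });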
+ */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new image embedder based on the + * path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ true); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the image embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the image embedder. + */ + override setOptions(options: ImageEmbedderOptions): Promise { + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + return this.applyOptions(options); + } + + /** + * Performs embedding extraction on the provided single image and waits + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `image`. + * + * @param image The image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The embedding result of the image + */ + embed(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + ImageEmbedderResult { + this.processImageData(image, imageProcessingOptions); + return this.embeddings; + } + + /** + * Performs embedding extraction on the provided video frame and waits + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `video`. + * + * @param imageFrame The image frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The embedding result of the image + */ + embedForVideo( + imageFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): ImageEmbedderResult { + this.processVideoData(imageFrame, imageProcessingOptions, timestamp); + return this.embeddings; + } + + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), have + * different sizes, or have an L2-norm of 0.
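Putting embed() and the static cosineSimilarity() helper together, a hedged sketch comparing two images (`imageA` and `imageB` are assumed image elements, and `embedder` an instance created as sketched earlier):

    const embeddingA = embedder.embed(imageA).embeddings[0];
    const embeddingB = embedder.embed(imageB).embeddings[0];

    // Throws if the embeddings differ in type (float vs. quantized) or size, or have an L2-norm of 0.
    const similarity = ImageEmbedder.cosineSimilarity(embeddingA, embeddingB);
    console.log(`cosine similarity: ${similarity.toFixed(3)}`);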
+ */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + + /** + * Internal function for converting raw data into an embedding, and setting it + * as our embeddings result. + */ + private addJsImageEmdedding(binaryProto: Uint8Array): void { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddings = convertFromEmbeddingResultProto(embeddingResult); + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension(ImageEmbedderGraphOptions.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('IMAGE:' + IMAGE_STREAM); + embedderNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.graphRunner.attachProtoListener( + EMBEDDINGS_STREAM, (binaryProto, timestamp) => { + this.addJsImageEmdedding(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts new file mode 100644 index 000000000..8a04be5e1 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options for configuring a MediaPipe Image Embedder task. */ +export declare interface ImageEmbedderOptions extends EmbedderOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts similarity index 83% rename from mediapipe/tasks/web/audio.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts index 4a3b80594..156636505 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts @@ -14,4 +14,4 @@ * limitations under the License. 
*/ -export * from '../../tasks/web/audio/index'; +export {Embedding, EmbeddingResult as ImageEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts new file mode 100644 index 000000000..5a8293c44 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -0,0 +1,160 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageEmbedder} from './image_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array, timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageEmbedder', () => { + let imageEmbedder: ImageEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageEmbedder = new ImageEmbedderFake(); + await imageEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(imageEmbedder); + verifyListenersRegistered(imageEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: true}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: undefined}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(imageEmbedder); + 
}); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('overrides options', async () => { + await imageEmbedder.setOptions({quantize: true}); + await imageEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + imageEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + describe('transforms result', () => { + beforeEach(() => { + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + embedding.setFloatEmbedding(floatEmbedding); + + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(42); + + // Pass the test data to our listener + imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageEmbedder); + imageEmbedder.protoListener!(resultProto.serializeBinary(), 1337); + }); + }); + + it('for image mode', async () => { + // Invoke the image embedder + const embeddingResult = imageEmbedder.embed({} as HTMLImageElement); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + + it('for video mode', async () => { + await imageEmbedder.setOptions({runningMode: 'VIDEO'}); + + // Invoke the video embedder + const embeddingResult = + imageEmbedder.embedForVideo({} as HTMLImageElement, 42); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD new file mode 100644 index 000000000..d15fe63f1 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -0,0 +1,58 @@ +# This contains the MediaPipe Image Segmenter Task. 
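// For reference alongside the ImageEmbedder implementation and tests above: a
// minimal, illustrative application-level sketch of the embedder API, written
// as if inside an async function. The wasm directory, model path and <img>
// elements are placeholders, and the FilesetResolver.forVisionTasks() helper
// is assumed to be available from the core package.
const vision = await FilesetResolver.forVisionTasks('/wasm');
const embedder = await ImageEmbedder.createFromOptions(vision, {
  baseOptions: {modelAssetPath: 'mobilenet_v3_embedder.tflite'},
  quantize: false,
});
const resultA =
    embedder.embed(document.getElementById('imageA') as HTMLImageElement);
const resultB =
    embedder.embed(document.getElementById('imageB') as HTMLImageElement);
// Embeddings from the same head can be compared with the static helper.
const similarity = ImageEmbedder.cosineSimilarity(
    resultA.embeddings[0], resultB.embeddings[0]);
console.log('cosine similarity:', similarity);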
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_segmenter", + srcs = ["image_segmenter.ts"], + deps = [ + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "image_segmenter_types", + srcs = ["image_segmenter_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "image_segmenter_test_lib", + testonly = True, + srcs = [ + "image_segmenter_test.ts", + ], + deps = [ + ":image_segmenter", + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "image_segmenter_test", + tags = ["nomsan"], + deps = [":image_segmenter_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts new file mode 100644 index 000000000..4f81977eb --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -0,0 +1,300 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; +import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {ImageSegmenterOptions} from './image_segmenter_options'; + +export * from './image_segmenter_options'; +export {ImageSource}; // Used in the public API + +/** + * The ImageSegmenter returns the segmentation result as a Uint8Array (when + * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for + * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved + * for future usage. + */ +export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; + +/** + * A callback that receives the computed masks from the image segmenter. The + * callback either receives a single element array with a category mask (as a + * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type SegmentationMaskCallback = + (masks: SegmentationMask[], width: number, height: number) => void; + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; +const IMAGEA_SEGMENTER_GRAPH = + 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs image segmentation on images. */ +export class ImageSegmenter extends VisionTaskRunner { + private userCallback: SegmentationMaskCallback = () => {}; + private readonly options: ImageSegmenterGraphOptionsProto; + private readonly segmenterOptions: SegmenterOptionsProto; + + /** + * Initializes the Wasm runtime and creates a new image segmenter from the + * provided options. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param imageSegmenterOptions The options for the Image Segmenter. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + imageSegmenterOptions: ImageSegmenterOptions): Promise { + return VisionTaskRunner.createInstance( + ImageSegmenter, /* initializeCanvas= */ true, wasmFileset, + imageSegmenterOptions); + } + + /** + * Initializes the Wasm runtime and creates a new image segmenter based on + * the provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. 
+ * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return VisionTaskRunner.createInstance( + ImageSegmenter, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new image segmenter based on + * the path to the model asset. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return VisionTaskRunner.createInstance( + ImageSegmenter, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); + this.options = new ImageSegmenterGraphOptionsProto(); + this.segmenterOptions = new SegmenterOptionsProto(); + this.options.setSegmenterOptions(this.segmenterOptions); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the image segmenter. + * + * Calling `setOptions()` with a subset of options only affects those + * options. You can reset an option back to its default value by + * explicitly setting it to `undefined`. + * + * @param options The options for the image segmenter. + */ + override setOptions(options: ImageSegmenterOptions): Promise { + // Note that we have to support both JSPB and ProtobufJS, hence we + // have to expliclity clear the values instead of setting them to + // `undefined`. + if (options.displayNamesLocale !== undefined) { + this.options.setDisplayNamesLocale(options.displayNamesLocale); + } else if ('displayNamesLocale' in options) { // Check for undefined + this.options.clearDisplayNamesLocale(); + } + + if (options.outputType === 'CONFIDENCE_MASK') { + this.segmenterOptions.setOutputType( + SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); + } else { + this.segmenterOptions.setOutputType( + SegmenterOptionsProto.OutputType.CATEGORY_MASK); + } + + return super.applyOptions(options); + } + + /** + * Performs image segmentation on the provided single image and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `image`. + * + * @param image An image to process. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segment(image: ImageSource, callback: SegmentationMaskCallback): void; + /** + * Performs image segmentation on the provided single image and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `image`. + * + * @param image An image to process. 
+ * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segment( + image: ImageSource, imageProcessingOptions: ImageProcessingOptions, + callback: SegmentationMaskCallback): void; + segment( + image: ImageSource, + imageProcessingOptionsOrCallback: ImageProcessingOptions| + SegmentationMaskCallback, + callback?: SegmentationMaskCallback): void { + const imageProcessingOptions = + typeof imageProcessingOptionsOrCallback !== 'function' ? + imageProcessingOptionsOrCallback : + {}; + + this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + imageProcessingOptionsOrCallback : + callback!; + this.processImageData(image, imageProcessingOptions); + this.userCallback = () => {}; + } + + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, timestamp: number, + callback: SegmentationMaskCallback): void; + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number, callback: SegmentationMaskCallback): void; + segmentForVideo( + videoFrame: ImageSource, + timestampOrImageProcessingOptions: number|ImageProcessingOptions, + timestampOrCallback: number|SegmentationMaskCallback, + callback?: SegmentationMaskCallback): void { + const imageProcessingOptions = + typeof timestampOrImageProcessingOptions !== 'number' ? + timestampOrImageProcessingOptions : + {}; + const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? + timestampOrImageProcessingOptions : + timestampOrCallback as number; + + this.userCallback = typeof timestampOrCallback === 'function' ? + timestampOrCallback : + callback!; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + this.userCallback = () => {}; + } + + /** Updates the MediaPipe graph configuration. 
*/ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ImageSegmenterGraphOptionsProto.ext, this.options); + + const segmenterNode = new CalculatorGraphConfig.Node(); + segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH); + segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM); + segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + segmenterNode.addOutputStream( + 'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM); + segmenterNode.setOptions(calculatorOptions); + + graphConfig.addNode(segmenterNode); + + this.graphRunner.attachImageVectorListener( + GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => { + if (masks.length === 0) { + this.userCallback([], 0, 0); + } else { + this.userCallback( + masks.map(m => m.data), masks[0].width, masks[0].height); + } + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts new file mode 100644 index 000000000..c17e7e421 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Image Segmenter Task */ +export interface ImageSegmenterOptions extends VisionTaskOptions { + /** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ + displayNamesLocale?: string|undefined; + + /** + * The output type of segmentation results. + * + * The two supported modes are: + * - Category Mask: Gives a single output mask where each pixel represents + * the class which the pixel in the original image was + * predicted to belong to. + * - Confidence Mask: Gives a list of output masks (one for each class). For + * each mask, the pixel represents the prediction + * confidence, usually in the [0.0, 0.1] range. + * + * Defaults to `CATEGORY_MASK`. + */ + outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; +} diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts new file mode 100644 index 000000000..aa81be025 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -0,0 +1,215 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; + +import {ImageSegmenter} from './image_segmenter'; +import {ImageSegmenterOptions} from './image_segmenter_options'; + +class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + imageVectorListener: + ((images: WasmImage[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachImageVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('segmented_masks'); + this.imageVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageSegmenter', () => { + let imageSegmenter: ImageSegmenterFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageSegmenter = new ImageSegmenterFake(); + await imageSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(imageSegmenter); + verifyListenersRegistered(imageSegmenter); + }); + + it('reloads graph when settings are changed', async () => { + await imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + verifyListenersRegistered(imageSegmenter); + + await imageSegmenter.setOptions({displayNamesLocale: 'de'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); + verifyListenersRegistered(imageSegmenter); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageSegmenter.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageSegmenter, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + await 
imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof ImageSegmenterOptions; + fieldPath: string[]; + userValue: unknown; + graphValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'displayNamesLocale', + fieldPath: ['displayNamesLocale'], + userValue: 'en', + graphValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'outputType', + fieldPath: ['segmenterOptions', 'outputType'], + userValue: 'CONFIDENCE_MASK', + graphValue: 2, + defaultValue: 1 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + await imageSegmenter.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + imageSegmenter, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + imageSegmenter.segment( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('supports category masks', (done) => { + const mask = new Uint8Array([1, 2, 3, 4]); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageSegmenter); + imageSegmenter.imageVectorListener!( + [ + {data: mask, width: 2, height: 2}, + ], + /* timestamp= */ 1337); + }); + + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(masks).toHaveSize(1); + expect(masks[0]).toEqual(mask); + expect(width).toEqual(2); + expect(height).toEqual(2); + done(); + }); + }); + + it('supports confidence masks', async () => { + const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); + const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); + + await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageSegmenter); + imageSegmenter.imageVectorListener!( + [ + {data: mask1, width: 2, height: 2}, + {data: mask2, width: 2, height: 2}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(masks).toHaveSize(2); + expect(masks[0]).toEqual(mask1); + expect(masks[1]).toEqual(mask2); + expect(width).toEqual(2); + expect(height).toEqual(2); + resolve(); + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 7cc915f25..5a87c7a82 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,17 +14,30 @@ * limitations under the License. 
*/ -// Image Classifier -export * from '../../../tasks/web/vision/image_classifier/image_classifier_options'; -export * from '../../../tasks/web/vision/image_classifier/image_classifier_result'; -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; +import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter'; +import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; -// Gesture Recognizer -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_options'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer_result'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ImageSegmenter = ImageSegementerImpl; +const ObjectDetector = ObjectDetectorImpl; -// Object Detector -export * from '../../../tasks/web/vision/object_detector/object_detector_options'; -export * from '../../../tasks/web/vision/object_detector/object_detector_result'; -export * from '../../../tasks/web/vision/object_detector/object_detector'; +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ImageSegmenter, + ObjectDetector +}; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 095a84b52..76fa589c8 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -3,7 +3,8 @@ # This task takes video frames and outputs synchronized frames along with # the detection results for one or more object categories, using Object Detector. 
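// To show how the symbols re-exported from vision/index.ts above are intended
// to be consumed, a short sketch of the ImageSegmenter callback API in video
// mode, written as if inside an async function; the wasm directory, model path
// and the FilesetResolver.forVisionTasks() helper are assumptions used only
// for illustration.
const vision = await FilesetResolver.forVisionTasks('/wasm');
const segmenter = await ImageSegmenter.createFromOptions(vision, {
  baseOptions: {modelAssetPath: 'selfie_segmentation.tflite'},
  runningMode: 'VIDEO',
  outputType: 'CONFIDENCE_MASK',
});
const video = document.querySelector('video')!;
segmenter.segmentForVideo(video, performance.now(), (masks, width, height) => {
  // With CONFIDENCE_MASK, `masks` holds one Float32Array per class. The data
  // is only valid for the duration of the callback and must be copied if it
  // is needed asynchronously.
  console.log(`received ${masks.length} masks of ${width}x${height}`);
});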
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,20 +12,57 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", - srcs = [ - "object_detector.ts", - "object_detector_options.ts", - "object_detector_result.ts", - ], + srcs = ["object_detector.ts"], + visibility = ["//visibility:public"], deps = [ + ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_declaration( + name = "object_detector_types", + srcs = [ + "object_detector_options.d.ts", + "object_detector_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "object_detector_test_lib", + testonly = True, + srcs = [ + "object_detector_test.ts", + ], + deps = [ + ":object_detector", + ":object_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "object_detector_test", + tags = ["nomsan"], + deps = [":object_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 022bf6301..5b581432d 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -17,88 +17,99 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; 
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; import {Detection} from './object_detector_result'; -const INPUT_STREAM = 'input_frame_gpu'; +const IMAGE_STREAM = 'input_frame_gpu'; +const NORM_RECT_STREAM = 'norm_rect'; const DETECTIONS_STREAM = 'detections'; const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; const DEFAULT_CATEGORY_INDEX = -1; +export * from './object_detector_options'; +export * from './object_detector_result'; export {ImageSource}; // Used in the public API // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern /** Performs object detection on images. */ -export class ObjectDetector extends TaskRunner { +export class ObjectDetector extends VisionTaskRunner { private detections: Detection[] = []; private readonly options = new ObjectDetectorOptionsProto(); /** * Initializes the Wasm runtime and creates a new object detector from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param objectDetectorOptions The options for the Object Detector. Note that * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const detector = await createMediaPipeLib( - ObjectDetector, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await detector.setOptions(objectDetectorOptions); - return detector; + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + objectDetectorOptions); } /** * Initializes the Wasm runtime and creates a new object detector based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ObjectDetector.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new object detector based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. 
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ObjectDetector.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -110,13 +121,7 @@ export class ObjectDetector extends TaskRunner { * * @param options The options for the object detector. */ - async setOptions(options: ObjectDetectorOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override setOptions(options: ObjectDetectorOptions): Promise { // Note that we have to support both JSPB and ProtobufJS, hence we // have to expliclity clear the values instead of setting them to // `undefined`. @@ -150,23 +155,42 @@ export class ObjectDetector extends TaskRunner { this.options.clearCategoryDenylistList(); } - this.refreshGraph(); + return this.applyOptions(options); } /** * Performs object detection on the provided single image and waits - * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * synchronously for the response. Only use this method when the + * ObjectDetector is created with running mode `image`. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. * @return The list of detected objects */ - detect(imageSource: ImageSource, timestamp?: number): Detection[] { - // Get detections by running our MediaPipe graph. + detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + Detection[] { this.detections = []; - this.addGpuBufferAsImageToStream( - imageSource, INPUT_STREAM, timestamp ?? performance.now()); - this.finishProcessing(); + this.processImageData(image, imageProcessingOptions); + return [...this.detections]; + } + + /** + * Performs object detection on the provided video frame and waits + * synchronously for the response. Only use this method when the + * ObjectDetector is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. 
+ * @return The list of detected objects + */ + detectForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): Detection[] { + this.detections = []; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); return [...this.detections]; } @@ -204,9 +228,10 @@ export class ObjectDetector extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); - graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM); const calculatorOptions = new CalculatorOptions(); @@ -215,15 +240,18 @@ export class ObjectDetector extends TaskRunner { const detectorNode = new CalculatorGraphConfig.Node(); detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH); - detectorNode.addInputStream('IMAGE:' + INPUT_STREAM); + detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM); + detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM); detectorNode.setOptions(calculatorOptions); graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, (binaryProto, timestamp) => { + this.addJsObjectDetections(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts new file mode 100644 index 000000000..7564e7760 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Object Detector Task */ +export interface ObjectDetectorOptions extends VisionTaskOptions, + ClassifierOptions {} diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts deleted file mode 100644 index eec12cf17..000000000 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - - /** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ - displayNamesLocale?: string|undefined; - - /** The maximum number of top-scored detection results to return. */ - maxResults?: number|undefined; - - /** - * Overrides the value provided in the model metadata. Results below this - * value are rejected. - */ - scoreThreshold?: number|undefined; - - /** - * Allowlist of category names. If non-empty, detection results whose category - * name is not in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryDenylist`. - */ - categoryAllowlist?: string[]|undefined; - - /** - * Denylist of category names. If non-empty, detection results whose category - * name is in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryAllowlist`. - */ - categoryDenylist?: string[]|undefined; -} diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts similarity index 98% rename from mediapipe/tasks/web/vision/object_detector/object_detector_result.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts index e9e3843bc..c9c87a1bf 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts @@ -16,6 +16,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; +export {Category}; + /** An integer bounding box, axis aligned. */ export declare interface BoundingBox { /** The X coordinate of the top-left corner, in pixels. */ diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts new file mode 100644 index 000000000..9dd64c0b6 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -0,0 +1,239 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ObjectDetector} from './object_detector'; +import {ObjectDetectorOptions} from './object_detector_options'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ObjectDetector', () => { + let objectDetector: ObjectDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + objectDetector = new ObjectDetectorFake(); + await objectDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(objectDetector); + verifyListenersRegistered(objectDetector); + }); + + it('reloads graph when settings are changed', async () => { + await objectDetector.setOptions({maxResults: 1}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyListenersRegistered(objectDetector); + + await objectDetector.setOptions({maxResults: 5}); + verifyGraph(objectDetector, ['maxResults', 5]); + verifyListenersRegistered(objectDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await objectDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + objectDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await objectDetector.setOptions({maxResults: 1}); + await objectDetector.setOptions({displayNamesLocale: 'en'}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyGraph(objectDetector, ['displayNamesLocale', 'en']); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof ObjectDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 
'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + await objectDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + objectDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + objectDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + const detectionProtos: Uint8Array[] = []; + + // Add a detection with all optional properties + let detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + let locationData = new LocationData(); + let boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Add a detection without optional properties + detection = new DetectionProto(); + detection.addScore(0.2); + locationData = new LocationData(); + boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Pass the test data to our listener + objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(objectDetector); + objectDetector.protoListener!(detectionProtos, 1337); + }); + + // Invoke the object detector + const detections = objectDetector.detect({} as HTMLImageElement); + + expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(2); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + }); + expect(detections[1]).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts new file mode 100644 index 000000000..b9d951f60 --- /dev/null +++ b/mediapipe/tasks/web/vision/types.ts @@ -0,0 +1,23 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/image_segmenter/image_segmenter'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index ab3390e0a..555569552 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -186,6 +186,7 @@ cc_library( hdrs = [ "resource_util.h", ], + # We use Objective-C++ on iOS. copts = select({ "//conditions:default": [], @@ -228,6 +229,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", ], ) @@ -367,3 +369,21 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "image_test_utils", + testonly = 1, + srcs = ["image_test_utils.cc"], + hdrs = ["image_test_utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + ], +) diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc new file mode 100644 index 000000000..815666985 --- /dev/null +++ b/mediapipe/util/image_test_utils.cc @@ -0,0 +1,57 @@ +#include "mediapipe/util/image_test_utils.h" + +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +cv::Mat GetRgb(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat rgb; + cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); + return rgb; +} + +cv::Mat GetRgba(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat rgb; + cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); + return rgb; +} + +cv::Mat GetGray(const std::string& path) { + cv::Mat bgr = cv::imread(path); + cv::Mat gray; + cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); + return gray; +} + +mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { + if (image_channels == 4) { + return ImageFormat::SRGBA; + } else if (image_channels == 3) { + return ImageFormat::SRGB; + } else if (image_channels == 1) { + return 
ImageFormat::GRAY8; + } + LOG(FATAL) << "Unsupported input image channles: " << image_channels; +} + +Packet MakeImageFramePacket(cv::Mat input, int timestamp) { + ImageFrame input_image(GetImageFormat(input.channels()), input.cols, + input.rows, input.step, input.data, [](uint8*) {}); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +Packet MakeImagePacket(cv::Mat input, int timestamp) { + mediapipe::Image input_image(std::make_shared( + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +} // namespace mediapipe diff --git a/mediapipe/util/image_test_utils.h b/mediapipe/util/image_test_utils.h new file mode 100644 index 000000000..6df9644d2 --- /dev/null +++ b/mediapipe/util/image_test_utils.h @@ -0,0 +1,32 @@ +#ifndef MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ +#define MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ + +#include + +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/opencv_core_inc.h" + +namespace mediapipe { + +// Reads the image file into cv::Mat with RGB channels. +cv::Mat GetRgb(const std::string& path); + +// Reads the image file into cv::Mat with RGBA channels. +cv::Mat GetRgba(const std::string& path); + +// Reads the image file into cv::Mat with Gray channel. +cv::Mat GetGray(const std::string& path); + +// Converts the image channels into corresponding ImageFormat. +mediapipe::ImageFormat::Format GetImageFormat(int image_channels); + +// Converts the cv::Mat into ImageFrame packet. +Packet MakeImageFramePacket(cv::Mat input, int timestamp = 0); + +// Converts the cv::Mat into Image packet. +Packet MakeImagePacket(cv::Mat input, int timestamp = 0); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ diff --git a/mediapipe/util/log_fatal_to_breakpad.cc b/mediapipe/util/log_fatal_to_breakpad.cc new file mode 100644 index 000000000..45087f2e3 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.cc @@ -0,0 +1,50 @@ +#include "mediapipe/util/log_fatal_to_breakpad.h" + +#import + +#include "absl/log/log.h" +#include "absl/log/log_sink.h" +#include "absl/log/log_sink_registry.h" +#import "googlemac/iPhone/Shared/GoogleIOSBreakpad/Classes/GoogleBreakpadController.h" + +namespace mediapipe { +namespace { +NSString* MakeNSString(absl::string_view str) { + return [[NSString alloc] initWithBytes:str.data() + length:str.length() + encoding:NSUTF8StringEncoding]; +} +} // namespace + +static NSString* const kFatalLogMessageKey = @"fatal_log_message"; + +class BreakpadFatalLogSink : public absl::LogSink { + public: + BreakpadFatalLogSink() + : breakpad_controller_([GoogleBreakpadController sharedInstance]) {} + void Send(const absl::LogEntry& entry) override { + if (entry.log_severity() != absl::LogSeverity::kFatal) return; + __block NSString* message = MakeNSString(entry.text_message_with_prefix()); + [breakpad_controller_ withBreakpadRef:^(BreakpadRef breakpad) { + // NOTE: This block runs on Breakpad's background queue. + if (!breakpad) return; + BreakpadAddUploadParameter(breakpad, kFatalLogMessageKey, message); + }]; + } + + private: + GoogleBreakpadController* breakpad_controller_; +}; + +absl::LogSink* GetBreakpadFatalLogSink() { + static BreakpadFatalLogSink sink; + return &sink; +} + +// This log sink is automatically enabled when including this library. 
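// (As a hedged sketch of typical usage: a test or application target only
// needs to list this library in its deps, e.g. something like
// "//mediapipe/util:log_fatal_to_breakpad" depending on the actual BUILD rule,
// and every LOG(FATAL)/CHECK failure message is then attached to the Breakpad
// report under the "fatal_log_message" key by the sink registered below;
// calling mediapipe::GetBreakpadFatalLogSink() directly is only needed to
// remove the sink again via absl::RemoveLogSink().)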
+static const auto kRegisterLogSink = [] { + absl::AddLogSink(GetBreakpadFatalLogSink()); + return true; +}(); + +} // namespace mediapipe diff --git a/mediapipe/util/log_fatal_to_breakpad.h b/mediapipe/util/log_fatal_to_breakpad.h new file mode 100644 index 000000000..1712a9af8 --- /dev/null +++ b/mediapipe/util/log_fatal_to_breakpad.h @@ -0,0 +1,15 @@ +#ifndef MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ +#define MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ + +#include "absl/log/log_sink.h" + +namespace mediapipe { + +// Returns a singleton instance of a log sink that sends FATAL log messages to +// Breakpad. This log sink is enabled by default when this library is included +// in your binary. +absl::LogSink* GetBreakpadFatalLogSink(); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_LOG_FATAL_TO_BREAKPAD_H_ diff --git a/mediapipe/util/packet_test_util.h b/mediapipe/util/packet_test_util.h index 106d7f8d4..61e9322e1 100644 --- a/mediapipe/util/packet_test_util.h +++ b/mediapipe/util/packet_test_util.h @@ -32,30 +32,29 @@ namespace mediapipe { namespace internal { template -class PacketMatcher : public ::testing::MatcherInterface { +class PacketMatcher : public testing::MatcherInterface { public: template explicit PacketMatcher(InnerMatcher inner_matcher) : inner_matcher_( - ::testing::SafeMatcherCast(inner_matcher)) {} + testing::SafeMatcherCast(inner_matcher)) {} // Returns true iff the packet contains value of PayloadType satisfying // the inner matcher. - bool MatchAndExplain( - const Packet& packet, - ::testing::MatchResultListener* listener) const override { + bool MatchAndExplain(const Packet& packet, + testing::MatchResultListener* listener) const override { if (!packet.ValidateAsType().ok()) { *listener << packet.DebugString() << " does not contain expected type " << ExpectedTypeName(); return false; } - ::testing::StringMatchResultListener match_listener; + testing::StringMatchResultListener match_listener; const PayloadType& payload = packet.Get(); const bool matches = inner_matcher_.MatchAndExplain(payload, &match_listener); const std::string explanation = match_listener.str(); *listener << packet.DebugString() << " containing value " - << ::testing::PrintToString(payload); + << testing::PrintToString(payload); if (!explanation.empty()) { *listener << ", which " << explanation; } @@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface { return ::mediapipe::Demangle(typeid(PayloadType).name()); } - const ::testing::Matcher inner_matcher_; + const testing::Matcher inner_matcher_; }; +inline std::string SourceString(Timestamp t) { + return (t.IsSpecialValue()) + ? 
t.DebugString() + : absl::StrCat("Timestamp(", t.DebugString(), ")"); +} + +template +std::string SourceString(Packet packet) { + std::ostringstream oss; + if (packet.IsEmpty()) { + oss << "Packet()"; + } else { + oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" + << packet.Get() << ")"; + } + oss << ".At(" << SourceString(packet.Timestamp()) << ")"; + return oss.str(); +} + } // namespace internal // Creates matcher validating that the packet contains value of expected type @@ -91,9 +109,9 @@ class PacketMatcher : public ::testing::MatcherInterface { // // EXPECT_THAT(MakePacket(42), PacketContains(Eq(42))) template -inline ::testing::Matcher PacketContains( +inline testing::Matcher PacketContains( InnerMatcher inner_matcher) { - return ::testing::MakeMatcher( + return testing::MakeMatcher( new internal::PacketMatcher(inner_matcher)); } @@ -110,7 +128,7 @@ inline ::testing::Matcher PacketContains( // Eq(42))) template -inline ::testing::Matcher PacketContainsTimestampAndPayload( +inline testing::Matcher PacketContainsTimestampAndPayload( TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) { return testing::AllOf( testing::Property("Packet::Timestamp", &Packet::Timestamp, @@ -118,6 +136,46 @@ inline ::testing::Matcher PacketContainsTimestampAndPayload( PacketContains(content_matcher)); } +template +class PacketEqMatcher : public testing::MatcherInterface { + public: + PacketEqMatcher(Packet packet) : packet_(packet) {} + void DescribeTo(::std::ostream* os) const override { + *os << "The expected packet: " << internal::SourceString(packet_); + } + bool MatchAndExplain(Packet value, + testing::MatchResultListener* listener) const override { + bool unequal = (value.Timestamp() != packet_.Timestamp() || + value.IsEmpty() != packet_.IsEmpty() || + (!value.IsEmpty() && value.Get() != packet_.Get())); + if (unequal && listener->IsInterested()) { + *listener << "The actual packet: " << internal::SourceString(value); + } + return !unequal; + } + const Packet packet_; +}; + +template +testing::Matcher PacketEq(Packet packet) { + return MakeMatcher(new PacketEqMatcher(packet)); +} + +template +std::vector> PacketMatchers( + std::vector packets) { + std::vector> result; + for (const auto& packet : packets) { + result.push_back(PacketEq(packet)); + } + return result; +} + +} // namespace mediapipe + +namespace mediapipe { +using mediapipe::PacketContains; +using mediapipe::PacketContainsTimestampAndPayload; } // namespace mediapipe #endif // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_ diff --git a/mediapipe/util/rectangle_util_test.cc b/mediapipe/util/rectangle_util_test.cc index cd1946d45..3bc323f9f 100644 --- a/mediapipe/util/rectangle_util_test.cc +++ b/mediapipe/util/rectangle_util_test.cc @@ -20,6 +20,7 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; using ::testing::FloatNear; class RectangleUtilTest : public testing::Test { diff --git a/mediapipe/util/resource_cache.h b/mediapipe/util/resource_cache.h index 4cd869f6a..2b3ccbc7d 100644 --- a/mediapipe/util/resource_cache.h +++ b/mediapipe/util/resource_cache.h @@ -17,6 +17,7 @@ #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/function_ref.h" #include "mediapipe/framework/port/logging.h" @@ -26,7 +27,8 @@ namespace mediapipe { // resource (e.g., image dimension for an image pool) is described bye the `Key` // type. The `Value` type must include an unset value, with implicit conversion // to bool reflecting set/unset state. 
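+// The optional `KeyHash` template parameter selects the hash functor used for
+// `Key`; by default it reuses the hasher absl::flat_hash_map would pick.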
-template <typename Key, typename Value>
+template <typename Key, typename Value,
+          typename KeyHash = typename absl::flat_hash_map<Key, Value>::hasher>
 class ResourceCache {
  public:
   Value Lookup(
@@ -36,15 +38,14 @@ class ResourceCache {
     Entry* entry;
     if (map_it == map_.end()) {
       std::tie(map_it, std::ignore) =
-          map_.emplace(std::piecewise_construct, std::forward_as_tuple(key),
-                       std::forward_as_tuple(key));
-      entry = &map_it->second;
+          map_.try_emplace(key, std::make_unique<Entry>(key));
+      entry = map_it->second.get();
       CHECK_EQ(entry->request_count, 0);
       entry->request_count = 1;
       entry_list_.Append(entry);
       if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1);
     } else {
-      entry = &map_it->second;
+      entry = map_it->second.get();
       ++entry->request_count;
       Entry* larger = entry->prev;
       while (larger != nullptr &&
@@ -171,7 +172,7 @@ class ResourceCache {
     size_t size_ = 0;
   };
 
-  std::unordered_map<Key, Entry> map_;
+  absl::flat_hash_map<Key, std::unique_ptr<Entry>, KeyHash> map_;
   EntryList entry_list_;
   int total_request_count_ = 0;
 };
diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc
index 8f40154a0..38636f32e 100644
--- a/mediapipe/util/resource_util.cc
+++ b/mediapipe/util/resource_util.cc
@@ -37,6 +37,8 @@ absl::Status GetResourceContents(const std::string& path, std::string* output,
   return internal::DefaultGetResourceContents(path, output, read_as_binary);
 }
 
+bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; }
+
 void SetCustomGlobalResourceProvider(ResourceProviderFn fn) {
   resource_provider_ = std::move(fn);
 }
diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h
index 6bc1513c6..e74af8b2e 100644
--- a/mediapipe/util/resource_util_custom.h
+++ b/mediapipe/util/resource_util_custom.h
@@ -10,6 +10,9 @@ namespace mediapipe {
 typedef std::function<absl::Status(const std::string&, std::string*)>
     ResourceProviderFn;
 
+// Returns true if files are provided via a custom resource provider.
+bool HasCustomGlobalResourceProvider();
+
 // Overrides the behavior of GetResourceContents.
void SetCustomGlobalResourceProvider(ResourceProviderFn fn); diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index 40a474599..42b0e3889 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -802,7 +802,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) { tensorflow::SequenceExample sequence; cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {})); std::string encoded_image(bytes.begin(), bytes.end()); AddImageEncoded(encoded_image, &sequence); AddImageEncoded(encoded_image, &sequence); @@ -843,7 +843,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) { tensorflow::SequenceExample sequence; cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {})); std::string encoded_flow(bytes.begin(), bytes.end()); AddForwardFlowEncoded(encoded_flow, &sequence); diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index dfbc8d659..5eeaa230f 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -21,6 +21,7 @@ #include "absl/status/status.h" #include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -89,7 +90,8 @@ class TFLiteGPURunner { serialized_binary_cache_ = std::move(cache); } - std::vector GetSerializedBinaryCache() { + absl::StatusOr> GetSerializedBinaryCache() { + RET_CHECK(cl_environment_) << "CL environment is not initialized."; return cl_environment_->GetSerializedBinaryCache(); } diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 319e99d5b..816af2533 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -134,7 +134,6 @@ proto_library( mediapipe_cc_proto_library( name = "tone_models_cc_proto", srcs = ["tone_models.proto"], - visibility = ["//visibility:public"], deps = [":tone_models_proto"], ) @@ -142,7 +141,6 @@ mediapipe_cc_proto_library( name = "tone_estimation_cc_proto", srcs = ["tone_estimation.proto"], cc_deps = [":tone_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tone_estimation_proto"], ) @@ -153,21 +151,18 @@ mediapipe_cc_proto_library( ":tone_estimation_cc_proto", ":tone_models_cc_proto", ], - visibility = ["//visibility:public"], deps = [":region_flow_computation_proto"], ) mediapipe_cc_proto_library( name = "motion_saliency_cc_proto", srcs = ["motion_saliency.proto"], - visibility = ["//visibility:public"], deps = [":motion_saliency_proto"], ) mediapipe_cc_proto_library( name = "motion_estimation_cc_proto", srcs = ["motion_estimation.proto"], - visibility = ["//visibility:public"], deps = [":motion_estimation_proto"], ) @@ -179,7 +174,6 @@ mediapipe_cc_proto_library( ":motion_saliency_cc_proto", ":region_flow_computation_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_proto"], ) @@ -187,14 +181,12 @@ mediapipe_cc_proto_library( name = "region_flow_cc_proto", srcs = ["region_flow.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":region_flow_proto"], ) mediapipe_cc_proto_library( 
name = "motion_models_cc_proto", srcs = ["motion_models.proto"], - visibility = ["//visibility:public"], deps = [":motion_models_proto"], ) @@ -202,21 +194,18 @@ mediapipe_cc_proto_library( name = "camera_motion_cc_proto", srcs = ["camera_motion.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":camera_motion_proto"], ) mediapipe_cc_proto_library( name = "push_pull_filtering_cc_proto", srcs = ["push_pull_filtering.proto"], - visibility = ["//visibility:public"], deps = [":push_pull_filtering_proto"], ) mediapipe_cc_proto_library( name = "frame_selection_solution_evaluator_cc_proto", srcs = ["frame_selection_solution_evaluator.proto"], - visibility = ["//visibility:public"], deps = [":frame_selection_solution_evaluator_proto"], ) @@ -228,7 +217,6 @@ mediapipe_cc_proto_library( ":frame_selection_solution_evaluator_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":frame_selection_proto"], ) @@ -239,7 +227,6 @@ mediapipe_cc_proto_library( ":motion_models_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_proto"], ) @@ -247,7 +234,6 @@ mediapipe_cc_proto_library( name = "tracking_cc_proto", srcs = ["tracking.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tracking_proto"], ) @@ -255,14 +241,12 @@ mediapipe_cc_proto_library( name = "box_tracker_cc_proto", srcs = ["box_tracker.proto"], cc_deps = [":tracking_cc_proto"], - visibility = ["//visibility:public"], deps = [":box_tracker_proto"], ) mediapipe_cc_proto_library( name = "tracked_detection_manager_config_cc_proto", srcs = ["tracked_detection_manager_config.proto"], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_config_proto"], ) @@ -273,7 +257,6 @@ mediapipe_cc_proto_library( ":box_tracker_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_proto"], ) @@ -299,7 +282,6 @@ cc_library( srcs = ["motion_models_cv.cc"], hdrs = ["motion_models_cv.h"], deps = [ - ":camera_motion_cc_proto", ":motion_models", ":motion_models_cc_proto", "//mediapipe/framework/port:opencv_core", @@ -458,7 +440,6 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", ], ) @@ -739,7 +720,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", diff --git a/mediapipe/util/tracking/motion_analysis.cc b/mediapipe/util/tracking/motion_analysis.cc index 0b7678889..5b6a970cf 100644 --- a/mediapipe/util/tracking/motion_analysis.cc +++ b/mediapipe/util/tracking/motion_analysis.cc @@ -791,7 +791,7 @@ void MotionAnalysis::VisualizeBlurAnalysisRegions(cv::Mat* input_view) { region_flow_computation_->ComputeBlurMask(*input_view, &corner_values, &mask); cv::Mat mask_3c; - cv::cvtColor(mask, mask_3c, CV_GRAY2RGB); + cv::cvtColor(mask, mask_3c, cv::COLOR_GRAY2RGB); cv::addWeighted(*input_view, 0.5, mask_3c, 0.5, -128, *input_view); } diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc index cfd5c23c2..708c868b5 100644 --- 
a/mediapipe/util/tracking/region_flow_computation.cc +++ b/mediapipe/util/tracking/region_flow_computation.cc @@ -30,6 +30,7 @@ #include "absl/container/node_hash_set.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_features2d_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_video_inc.h" @@ -935,12 +936,13 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, // Area based method best for downsampling. // For color images to temporary buffer. cv::Mat& resized = source.channels() == 1 ? dest_frame : *curr_color_image_; - cv::resize(source, resized, resized.size(), 0, 0, CV_INTER_AREA); + cv::resize(source, resized, resized.size(), 0, 0, cv::INTER_AREA); source_ptr = &resized; // Resize feature extraction mask if needed. if (!source_mask.empty()) { dest_mask.create(resized.rows, resized.cols, CV_8UC1); - cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, CV_INTER_NN); + cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, + cv::INTER_NEAREST); } } else if (!source_mask.empty()) { source_mask.copyTo(dest_mask); @@ -954,7 +956,7 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, const int dimension = visual_options.tiny_image_dimension(); data->tiny_image.create(dimension, dimension, type); cv::resize(*source_ptr, data->tiny_image, data->tiny_image.size(), 0, 0, - CV_INTER_AREA); + cv::INTER_AREA); } if (source_ptr->channels() == 1 && @@ -2286,7 +2288,7 @@ void RegionFlowComputation::ExtractFeatures( // Initialize mask from frame's feature extraction mask, by downsampling and // negating the latter mask. if (!data->mask.empty()) { - cv::resize(data->mask, mask, mask.size(), 0, 0, CV_INTER_NN); + cv::resize(data->mask, mask, mask.size(), 0, 0, cv::INTER_NEAREST); for (int y = 0; y < mask.rows; ++y) { uint8* mask_ptr = mask.ptr(y); for (int x = 0; x < mask.cols; ++x) { @@ -2590,12 +2592,6 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, cv::_InputArray input_frame2(data2.pyramid); #endif - // Using old c-interface for OpenCV's 2.2 tracker. 
- CvTermCriteria criteria; - criteria.type = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER; - criteria.max_iter = options_.tracking_options().tracking_iterations(); - criteria.epsilon = 0.02f; - feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 0ac6dc2a5..435a8e200 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -28,7 +28,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" -#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" diff --git a/mediapipe/util/tracking/tracked_detection.cc b/mediapipe/util/tracking/tracked_detection.cc index 130a87640..80a6981a8 100644 --- a/mediapipe/util/tracking/tracked_detection.cc +++ b/mediapipe/util/tracking/tracked_detection.cc @@ -20,6 +20,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + // Struct for carrying boundary information. struct NormalizedRectBounds { float left, right, top, bottom; diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 597827f3c..a9e348ceb 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,6 +21,7 @@ namespace { +using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view. 
diff --git a/mediapipe/util/tracking/tracked_detection_test.cc b/mediapipe/util/tracking/tracked_detection_test.cc index 60b9df1b1..13efaab92 100644 --- a/mediapipe/util/tracking/tracked_detection_test.cc +++ b/mediapipe/util/tracking/tracked_detection_test.cc @@ -18,6 +18,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + const float kErrorMargin = 1e-4f; TEST(TrackedDetectionTest, ConstructorWithoutBox) { diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD index dab6be50f..5c12947af 100644 --- a/mediapipe/web/graph_runner/BUILD +++ b/mediapipe/web/graph_runner/BUILD @@ -3,32 +3,24 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = [ - ":internal", "//mediapipe/tasks:internal", ]) -package_group( - name = "internal", - packages = [ - "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", - ], -) - mediapipe_ts_library( - name = "wasm_mediapipe_lib_ts", + name = "graph_runner_ts", srcs = [ - ":wasm_mediapipe_lib.ts", + ":graph_runner.ts", ], allow_unoptimized_namespaces = True, ) mediapipe_ts_library( - name = "wasm_mediapipe_image_lib_ts", + name = "graph_runner_image_lib_ts", srcs = [ - ":wasm_mediapipe_image_lib.ts", + ":graph_runner_image_lib.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) mediapipe_ts_library( @@ -37,5 +29,5 @@ mediapipe_ts_library( ":register_model_resources_graph_service.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/graph_runner.ts similarity index 82% rename from mediapipe/web/graph_runner/wasm_mediapipe_lib.ts rename to mediapipe/web/graph_runner/graph_runner.ts index 82a3a3f16..644d74918 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -15,9 +15,6 @@ export declare interface FileLocator { locateFile: (filename: string) => string; } -/** Listener to be passed in by user for handling output audio data. */ -export type AudioOutputListener = (output: Float32Array) => void; - /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -32,19 +29,14 @@ export declare interface WasmModule { _bindTextureToCanvas: () => boolean; _changeBinaryGraph: (size: number, dataPtr: number) => void; _changeTextGraph: (size: number, dataPtr: number) => void; - _configureAudio: - (channels: number, samples: number, sampleRate: number) => void; _free: (ptr: number) => void; _malloc: (size: number) => number; - _processAudio: (dataPtr: number, timestamp: number) => void; _processFrame: (width: number, height: number, timestamp: number) => void; _setAutoRenderToScreen: (enabled: boolean) => void; _waitUntilIdle: () => void; // Exposed so that clients of this lib can access this field dataFileDownloads?: {[url: string]: {loaded: number, total: number}}; - // Wasm module will call us back at this function when given audio data. - onAudioOutput?: AudioOutputListener; // Wasm Module multistream entrypoints. Require // gl_graph_runner_internal_multi_input as a build dependency. @@ -81,10 +73,11 @@ export declare interface WasmModule { // Wasm Module output listener entrypoints. Also built as part of // gl_graph_runner_internal_multi_input. 
- simpleListeners?: {[outputStreamName: string]: (data: unknown) => void}; + simpleListeners?: + {[outputStreamName: string]: (data: unknown, timestamp: number) => void}; vectorListeners?: { [outputStreamName: string]: ( - data: unknown, index: number, length: number) => void + data: unknown, index: number, length: number, timestamp: number) => void }; _attachBoolListener: (streamNamePtr: number) => void; _attachBoolVectorListener: (streamNamePtr: number) => void; @@ -100,11 +93,14 @@ export declare interface WasmModule { _attachProtoVectorListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; - // Requires dependency ":gl_graph_runner_audio_out", and will register an - // audio output listening function which can be tapped into dynamically during - // graph running via onAudioOutput. This call must be made before graph is - // initialized, but after wasmModule is instantiated. - _attachAudioOutputListener: () => void; + // Require dependency ":gl_graph_runner_audio_out" + _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + + // Require dependency ":gl_graph_runner_audio" + _addAudioToInputStream: (dataPtr: number, numChannels: number, + numSamples: number, streamNamePtr: number, timestamp: number) => void; + _configureAudio: (channels: number, samples: number, sampleRate: number, + streamNamePtr: number, headerNamePtr: number) => void; // TODO: Refactor to just use a few numbers (perhaps refactor away // from gl_graph_runner_internal.cc entirely to use something a little more @@ -129,7 +125,7 @@ declare global { declare function importScripts(...urls: Array): void; /** - * Valid types of image sources which we can run our WasmMediaPipeLib over. + * Valid types of image sources which we can run our GraphRunner over. */ export type ImageSource = HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; @@ -138,9 +134,11 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing WasmMediaPipeLib and -// subclasses. -type WasmMediaPipeConstructor = +/** + * Internal type of constructors used for initializing GraphRunner and + * subclasses. + */ +export type WasmMediaPipeConstructor = (new ( module: WasmModule, canvas?: HTMLCanvasElement|OffscreenCanvas|null) => LibType); @@ -151,7 +149,7 @@ type WasmMediaPipeConstructor = * into canvas, or else return the output WebGLTexture. Takes a WebAssembly * Module (must be instantiated to self.Module). */ -export class WasmMediaPipeLib { +export class GraphRunner { // TODO: These should be protected/private, but are left exposed for // now so that we can use proper TS mixins with this class as a base. This // should be somewhat fixed when we create our .d.ts files. @@ -181,10 +179,14 @@ export class WasmMediaPipeLib { if (glCanvas !== undefined) { this.wasmModule.canvas = glCanvas; - } else { + } else if (typeof OffscreenCanvas !== 'undefined') { // If no canvas is provided, assume Chrome/Firefox and just make an // OffscreenCanvas for GPU processing. this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } else { + console.warn('OffscreenCanvas not detected and GraphRunner constructor ' + + 'glCanvas parameter is undefined. 
Creating backup canvas.'); + this.wasmModule.canvas = document.createElement('canvas'); } } @@ -235,19 +237,38 @@ export class WasmMediaPipeLib { } /** - * Configures the current graph to handle audio in a certain way. Must be - * called before the graph is set/started in order to use processAudio. + * Configures the current graph to handle audio processing in a certain way + * for all its audio input streams. Additionally can configure audio headers + * (both input side packets as well as input stream headers), but these + * configurations only take effect if called before the graph is set/started. * @param numChannels The number of channels of audio input. Only 1 * is supported for now. * @param numSamples The number of samples that are taken in each * audio capture. * @param sampleRate The rate, in Hz, of the sampling. + * @param streamName The optional name of the input stream to additionally + * configure with audio information. This configuration only occurs before + * the graph is set/started. If unset, a default stream name will be used. + * @param headerName The optional name of the header input side packet to + * additionally configure with audio information. This configuration only + * occurs before the graph is set/started. If unset, a default header name + * will be used. */ - configureAudio(numChannels: number, numSamples: number, sampleRate: number) { - this.wasmModule._configureAudio(numChannels, numSamples, sampleRate); - if (this.wasmModule._attachAudioOutputListener) { - this.wasmModule._attachAudioOutputListener(); + configureAudio(numChannels: number, numSamples: number, sampleRate: number, + streamName?: string, headerName?: string) { + if (!this.wasmModule._configureAudio) { + console.warn( + 'Attempting to use configureAudio without support for input audio. ' + + 'Is build dep ":gl_graph_runner_audio" missing?'); } + streamName = streamName || 'input_audio'; + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + headerName = headerName || 'audio_header'; + this.wrapStringPtr(headerName, (headerNamePtr: number) => { + this.wasmModule._configureAudio(streamNamePtr, headerNamePtr, + numChannels, numSamples, sampleRate); + }); + }); } /** @@ -305,6 +326,10 @@ export class WasmMediaPipeLib { if ((imageSource as HTMLVideoElement).videoWidth) { width = (imageSource as HTMLVideoElement).videoWidth; height = (imageSource as HTMLVideoElement).videoHeight; + } else if ((imageSource as HTMLImageElement).naturalWidth) { + // TODO: Ensure this works with SVG images + width = (imageSource as HTMLImageElement).naturalWidth; + height = (imageSource as HTMLImageElement).naturalHeight; } else { width = imageSource.width; height = imageSource.height; @@ -394,10 +419,12 @@ export class WasmMediaPipeLib { * Ensures existence of the simple listeners table and registers the callback. * Intended for internal usage. */ - setListener(outputStreamName: string, callbackFcn: (data: T) => void) { + setListener( + outputStreamName: string, + callbackFcn: (data: T, timestamp: number) => void) { this.wasmModule.simpleListeners = this.wasmModule.simpleListeners || {}; this.wasmModule.simpleListeners[outputStreamName] = - callbackFcn as (data: unknown) => void; + callbackFcn as (data: unknown, timestamp: number) => void; } /** @@ -405,11 +432,12 @@ export class WasmMediaPipeLib { * Intended for internal usage. 
*/ setVectorListener( - outputStreamName: string, callbackFcn: (data: T[]) => void) { - const buffer: T[] = []; + outputStreamName: string, + callbackFcn: (data: T[], timestamp: number) => void) { + let buffer: T[] = []; this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {}; this.wasmModule.vectorListeners[outputStreamName] = - (data: unknown, index: number, length: number) => { + (data: unknown, index: number, length: number, timestamp: number) => { // The Wasm listener gets invoked once for each element. Once we // receive all elements, we invoke the registered callback with the // full array. @@ -418,7 +446,8 @@ export class WasmMediaPipeLib { // Invoke the user callback directly, as the Wasm layer may clean up // the underlying data elements once we leave the scope of the // listener. - callbackFcn(buffer); + callbackFcn(buffer, timestamp); + buffer = []; } }; } @@ -436,9 +465,36 @@ export class WasmMediaPipeLib { * processed. * @param audioData An array of raw audio capture data, like * from a call to getChannelData on an AudioBuffer. + * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. * @param timestamp The timestamp of the current frame, in ms. */ - addAudioToStream(audioData: Float32Array, timestamp: number) { + addAudioToStream( + audioData: Float32Array, streamName: string, timestamp: number) { + // numChannels and numSamples being 0 will cause defaults to be used, + // which will reflect values from last call to configureAudio. + this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp); + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed, shaping the audioData array into an audio matrix according to + * the numChannels and numSamples parameters. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param numChannels The number of audio channels this data represents. If 0 + * is passed, then the value will be taken from the last call to + * configureAudio. + * @param numSamples The number of audio samples captured in this data packet. + * If 0 is passed, then the value will be taken from the last call to + * configureAudio. + * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStreamWithShape( + audioData: Float32Array, numChannels: number, numSamples: number, + streamName: string, timestamp: number) { // 4 bytes for each F32 const size = audioData.length * 4; if (this.audioSize !== size) { @@ -449,7 +505,11 @@ export class WasmMediaPipeLib { this.audioSize = size; } this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); - this.wasmModule._processAudio(this.audioPtr!, timestamp); + + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addAudioToInputStream( + this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp); + }); } /** @@ -684,7 +744,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachBoolListener( - outputStreamName: string, callbackFcn: (data: boolean) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. 
this.setListener(outputStreamName, callbackFcn); @@ -704,7 +765,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachBoolVectorListener( - outputStreamName: string, callbackFcn: (data: boolean[]) => void): void { + outputStreamName: string, + callbackFcn: (data: boolean[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -724,7 +786,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachIntListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -744,7 +807,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachIntVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -764,7 +828,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachDoubleListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -784,7 +849,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachDoubleVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -804,7 +870,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachFloatListener( - outputStreamName: string, callbackFcn: (data: number) => void): void { + outputStreamName: string, + callbackFcn: (data: number, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -824,7 +891,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachFloatVectorListener( - outputStreamName: string, callbackFcn: (data: number[]) => void): void { + outputStreamName: string, + callbackFcn: (data: number[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -844,7 +912,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachStringListener( - outputStreamName: string, callbackFcn: (data: string) => void): void { + outputStreamName: string, + callbackFcn: (data: string, timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. 
this.setListener(outputStreamName, callbackFcn); @@ -864,7 +933,8 @@ export class WasmMediaPipeLib { * should not perform overly complicated (or any async) behavior. */ attachStringVectorListener( - outputStreamName: string, callbackFcn: (data: string[]) => void): void { + outputStreamName: string, + callbackFcn: (data: string[], timestamp: number) => void): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -894,7 +964,8 @@ export class WasmMediaPipeLib { * with it). */ attachProtoListener( - outputStreamName: string, callbackFcn: (data: Uint8Array) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array, timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setListener(outputStreamName, callbackFcn); @@ -928,7 +999,8 @@ export class WasmMediaPipeLib { * with it). */ attachProtoVectorListener( - outputStreamName: string, callbackFcn: (data: Uint8Array[]) => void, + outputStreamName: string, + callbackFcn: (data: Uint8Array[], timestamp: number) => void, makeDeepCopy?: boolean): void { // Set up our TS listener to receive any packets for this stream. this.setVectorListener(outputStreamName, callbackFcn); @@ -942,17 +1014,50 @@ export class WasmMediaPipeLib { } /** - * Sets a listener to be called back with audio output packet data, as a - * Float32Array, when graph has finished processing it. - * @param audioOutputListener The caller's listener function. + * Attaches an audio packet listener to the specified output_stream, to be + * given a Float32Array as output. + * @param outputStreamName The name of the graph output stream to grab audio + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. If the + * audio data needs to be able to outlive the call, you may set the + * optional makeDeepCopy parameter to true, or can manually deep-copy the + * data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). */ - setOnAudioOutput(audioOutputListener: AudioOutputListener) { - this.wasmModule.onAudioOutput = audioOutputListener; - if (!this.wasmModule._attachAudioOutputListener) { + attachAudioListener( + outputStreamName: string, + callbackFcn: (data: Float32Array, timestamp: number) => void, + makeDeepCopy?: boolean): void { + if (!this.wasmModule._attachAudioListener) { console.warn( - 'Attempting to use AudioOutputListener without support for ' + + 'Attempting to use attachAudioListener without support for ' + 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); } + + // Set up our TS listener to receive any packets for this stream, and + // additionally reformat our Uint8Array into a Float32Array for the user. 
+ this.setListener( + outputStreamName, (data: Uint8Array, timestamp: number) => { + // Should be very fast + const floatArray = + new Float32Array(data.buffer, data.byteOffset, data.length / 4); + callbackFcn(floatArray, timestamp); + }); + + // Tell our graph to listen for string packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachAudioListener( + outputStreamNamePtr, makeDeepCopy || false); + }); } /** @@ -971,7 +1076,7 @@ async function runScript(scriptUrl: string) { importScripts(scriptUrl.toString()); } else { const script = document.createElement('script'); - script.setAttribute('url', scriptUrl); + script.setAttribute('src', scriptUrl); script.setAttribute('crossorigin', 'anonymous'); return new Promise((resolve) => { script.addEventListener('load', () => { @@ -988,7 +1093,7 @@ async function runScript(scriptUrl: string) { /** * Global function to initialize Wasm blob and load runtime assets for a * specialized MediaPipe library. This allows us to create a requested - * subclass inheriting from WasmMediaPipeLib. + * subclass inheriting from GraphRunner. * @param constructorFcn The name of the class to instantiate via "new". * @param wasmLoaderScript Url for the wasm-runner script; produced by the build * process. @@ -1001,8 +1106,8 @@ async function runScript(scriptUrl: string) { */ export async function createMediaPipeLib( constructorFcn: WasmMediaPipeConstructor, - wasmLoaderScript?: string, - assetLoaderScript?: string, + wasmLoaderScript?: string|null, + assetLoaderScript?: string|null, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, fileLocator?: FileLocator): Promise { const scripts = []; @@ -1042,12 +1147,12 @@ export async function createMediaPipeLib( * @return promise A promise which will resolve when initialization has * completed successfully. */ -export async function createWasmMediaPipeLib( +export async function createGraphRunner( wasmLoaderScript?: string, assetLoaderScript?: string, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, - fileLocator?: FileLocator): Promise { + fileLocator?: FileLocator): Promise { return createMediaPipeLib( - WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + GraphRunner, wasmLoaderScript, assetLoaderScript, glCanvas, fileLocator); } diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts new file mode 100644 index 000000000..9608ebcc7 --- /dev/null +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -0,0 +1,113 @@ +import {GraphRunner, ImageSource} from './graph_runner'; + + + +/** + * We extend from a GraphRunner constructor. This ensures our mixin has + * access to the wasmModule, among other things. The `any` type is required for + * mixin constructors. + */ +// tslint:disable-next-line:no-any +type LibConstructor = new (...args: any[]) => GraphRunner; + +/** An image returned from a MediaPipe graph. */ +export interface WasmImage { + data: Uint8Array|Float32Array; + width: number; + height: number; +} +/** + * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler + * doesn't break our JS/C++ bridge. 
+ */ +export declare interface WasmImageModule { + _addBoundTextureAsImageToStream: + (streamNamePtr: number, width: number, height: number, + timestamp: number) => void; + _attachImageListener: (streamNamePtr: number) => void; + _attachImageVectorListener: (streamNamePtr: number) => void; +} + +/** + * An implementation of GraphRunner that supports binding GPU image data as + * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow + * for effective multiple inheritance. Example usage: `const GraphRunnerImageLib + * = SupportImage(GraphRunner);` + */ +// tslint:disable-next-line:enforce-name-casing +export function SupportImage(Base: TBase) { + return class extends Base { + get wasmImageModule(): WasmImageModule { + return this.wasmModule as unknown as WasmImageModule; + } + + /** + * Takes the relevant information from the HTML video or image element, + * and passes it into the WebGL-based graph for processing on the given + * stream at the given timestamp as a MediaPipe image. Processing will not + * occur until a blocking call (like processVideoGl or finishProcessing) + * is made. + * @param imageSource Reference to the video frame we wish to add into our + * graph. + * @param streamName The name of the MediaPipe graph stream to add the + * frame to. + * @param timestamp The timestamp of the input frame, in ms. + */ + addGpuBufferAsImageToStream( + imageSource: ImageSource, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + const [width, height] = + this.bindTextureToStream(imageSource, streamNamePtr); + this.wasmImageModule._addBoundTextureAsImageToStream( + streamNamePtr, width, height, timestamp); + }); + } + + /** + * Attaches a mediapipe:Image packet listener to the specified output + * stream. + * @param outputStreamName The name of the graph output stream to grab + * mediapipe::Image data from. + * @param callbackFcn The function that will be called back with the data, + * as it is received. Note that the data is only guaranteed to exist + * for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. + */ + attachImageListener( + outputStreamName: string, + callbackFcn: (data: WasmImage, timestamp: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for mediapipe::Image packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmImageModule._attachImageListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a mediapipe:Image[] packet listener to the specified + * output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, + * as it is received. Note that the data is only guaranteed to exist + * for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. + */ + attachImageVectorListener( + outputStreamName: string, + callbackFcn: (data: WasmImage[], timestamp: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on + // this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmImageModule._attachImageVectorListener(outputStreamNamePtr); + }); + } + }; +} diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index e85d63b06..9f2791d80 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -1,12 +1,12 @@ -import {WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -17,11 +17,11 @@ export declare interface WasmModuleRegisterModelResources { } /** - * An implementation of WasmMediaPipeLib that supports registering model + * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: - * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * WasmMediaPipeLib);` + * `const GraphRunnerWithModelResourcesLib = + * SupportModelResourcesGraphService(GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts deleted file mode 100644 index 3b45e8230..000000000 --- a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts +++ /dev/null @@ -1,52 +0,0 @@ -import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; - -/** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has - * access to the wasmModule, among other things. The `any` type is required for - * mixin constructors. - */ -// tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; - -/** - * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler - * doesn't break our JS/C++ bridge. - */ -export declare interface WasmImageModule { - _addBoundTextureAsImageToStream: - (streamNamePtr: number, width: number, height: number, - timestamp: number) => void; -} - -/** - * An implementation of WasmMediaPipeLib that supports binding GPU image data as - * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for - * effective multiple inheritance. Example usage: - * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` - */ -// tslint:disable-next-line:enforce-name-casing -export function SupportImage(Base: TBase) { - return class extends Base { - /** - * Takes the relevant information from the HTML video or image element, and - * passes it into the WebGL-based graph for processing on the given stream - * at the given timestamp as a MediaPipe image. Processing will not occur - * until a blocking call (like processVideoGl or finishProcessing) is made. - * @param imageSource Reference to the video frame we wish to add into our - * graph. 
- * @param streamName The name of the MediaPipe graph stream to add the frame - * to. - * @param timestamp The timestamp of the input frame, in ms. - */ - addGpuBufferAsImageToStream( - imageSource: ImageSource, streamName: string, timestamp: number): void { - this.wrapStringPtr(streamName, (streamNamePtr: number) => { - const [width, height] = - this.bindTextureToStream(imageSource, streamNamePtr); - (this.wasmModule as unknown as WasmImageModule) - ._addBoundTextureAsImageToStream( - streamNamePtr, width, height, timestamp); - }); - } - }; -} diff --git a/package.json b/package.json index 298157cbc..89b62bc83 100644 --- a/package.json +++ b/package.json @@ -3,13 +3,19 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/jasmine": "^5.7.2", "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", + "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", + "@types/jasmine": "^4.3.1", + "@types/node": "^18.11.11", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", + "jasmine": "^4.5.0", + "jasmine-core": "^4.5.0", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", "rollup": "^2.3.0", diff --git a/setup.py b/setup.py index b072a850e..992430cf1 100644 --- a/setup.py +++ b/setup.py @@ -490,10 +490,10 @@ setuptools.setup( 'Operating System :: MacOS :: MacOS X', 'Operating System :: Microsoft :: Windows', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD new file mode 100644 index 000000000..05f830e81 --- /dev/null +++ b/third_party/apple_frameworks/BUILD @@ -0,0 +1,73 @@ +# Build rules to inject Apple Frameworks + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "CoreGraphics", + linkopts = ["-framework CoreGraphics"], +) + +cc_library( + name = "CoreMedia", + linkopts = ["-framework CoreMedia"], +) + +cc_library( + name = "UIKit", + linkopts = ["-framework UIKit"], +) + +cc_library( + name = "Accelerate", + linkopts = ["-framework Accelerate"], +) + +cc_library( + name = "CoreVideo", + linkopts = ["-framework CoreVideo"], +) + +cc_library( + name = "Metal", + linkopts = ["-framework Metal"], +) + +cc_library( + name = "MetalPerformanceShaders", + linkopts = ["-framework MetalPerformanceShaders"], +) + +cc_library( + name = "AVFoundation", + linkopts = ["-framework AVFoundation"], +) + +cc_library( + name = "Foundation", + linkopts = ["-framework Foundation"], +) + +cc_library( + name = "CoreImage", + linkopts = ["-framework CoreImage"], +) + +cc_library( + name = "XCTest", + linkopts = ["-framework XCTest"], +) + +cc_library( + name = "GLKit", + linkopts = ["-framework GLKit"], +) + +cc_library( + name = "OpenGLES", + linkopts = ["-framework OpenGLES"], +) + +cc_library( + name = "QuartzCore", + linkopts = ["-framework QuartzCore"], +) diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff deleted file mode 100644 index 0cd2dffa4..000000000 --- 
a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel -index 9fceffe..e7f9d01 100644 ---- a/absl/time/internal/cctz/BUILD.bazel -+++ b/absl/time/internal/cctz/BUILD.bazel -@@ -69,8 +69,5 @@ cc_library( - "include/cctz/zone_info_source.h", - ], - linkopts = select({ -- ":osx": [ -- "-framework Foundation", -- ], - ":ios": [ - "-framework Foundation", - ], \ No newline at end of file diff --git a/third_party/com_google_absl_windows_patch.diff b/third_party/com_google_absl_windows_patch.diff new file mode 100644 index 000000000..a4b5b96bb --- /dev/null +++ b/third_party/com_google_absl_windows_patch.diff @@ -0,0 +1,13 @@ +diff --git a/absl/types/compare.h b/absl/types/compare.h +index 19b076e..0201004 100644 +--- a/absl/types/compare.h ++++ b/absl/types/compare.h +@@ -84,7 +84,7 @@ enum class ncmp : value_type { unordered = -127 }; + // based on whether the feature is supported. Note: we can't use + // ABSL_INTERNAL_INLINE_CONSTEXPR here because the variables here are of + // incomplete types so they need to be defined after the types are complete. +-#ifdef __cpp_inline_variables ++#if defined(__cpp_inline_variables) && !(defined(_MSC_VER) && _MSC_VER <= 1916) + + // A no-op expansion that can be followed by a semicolon at class level. + #define ABSL_COMPARE_INLINE_BASECLASS_DECL(name) static_assert(true, "") \ No newline at end of file diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 1f0b00289..1d9239c83 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -90,8 +90,8 @@ def external_files(): http_file( name = "com_google_mediapipe_canned_gesture_classifier_tflite", - sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2", - urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"], + sha256 = "ee121d85979de1b86126faabb0a0f4d2e4039c3e33e2cd687db50571001b24d0", + urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668550473107417"], ) http_file( @@ -240,14 +240,14 @@ def external_files(): http_file( name = "com_google_mediapipe_face_detection_full_range_sparse_tflite", - sha256 = "671dd2f9ed11a78436fc21cc42357a803dfc6f73e9fb86541be942d5716c2dce", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1661875739104017"], + sha256 = "2c3728e6da56f21e21a320433396fb06d40d9088f2247c05e5635a688d45dfe1", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range_sparse.tflite?generation=1674261618323821"], ) http_file( name = "com_google_mediapipe_face_detection_full_range_tflite", - sha256 = "99bf9494d84f50acc6617d89873f71bf6635a841ea699c17cb3377f9507cfec3", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1661875742733283"], + sha256 = "3698b18f063835bc609069ef052228fbe86d9c9a6dc8dcb7c7c2d69aed2b181b", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_full_range.tflite?generation=1674261620964007"], ) http_file( @@ -286,6 +286,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"], ) + http_file( + name = "com_google_mediapipe_fist_png", + sha256 = "4397b3d3f590c88a8de7d21c08d73a0df4a97fd93f92cbd086eef37fd246daaa", + urls = 
["https://storage.googleapis.com/mediapipe-assets/fist.png?generation=1672952068696274"], + ) + http_file( name = "com_google_mediapipe_general_meta_json", sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f", @@ -294,8 +300,8 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_tflite", - sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"], + sha256 = "927e4f6cbe6451da6b4fd1485e2576a6f8dbd95062666661cbd9dea893c41d01", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668550476472972"], ) http_file( @@ -706,6 +712,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) + http_file( + name = "com_google_mediapipe_portrait_expected_detection_pbtxt", + sha256 = "bb54e08e87844ef14bb185d5cb808908eb6011bfa6db48bd22d9650f6fda338b", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"], + ) + + http_file( + name = "com_google_mediapipe_portrait_jpg", + sha256 = "a6f11efaa834706db23f275b6115058fa87fc7f14362681e6abe14e82749de3e", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait.jpg?generation=1674261630039907"], + ) + http_file( name = "com_google_mediapipe_pose_detection_tflite", sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85", @@ -990,14 +1008,26 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", - sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"], + sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668550482128410"], ) http_file( name = "com_google_mediapipe_gesture_embedder_saved_model_pb", - sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"], + sha256 = "0082d37c5b85487fbf553e00a63f640945faf3da2d561a5f5a24c3194fecda6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", + sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb", + sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"], ) http_file( @@ -1038,12 +1068,30 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", - sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"], + 
sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668550487965052"], ) http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_index", - sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"], + sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", + sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001", + sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index", + sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"], ) diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 6bfde21ba..017d84466 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,36 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"], - ) - - http_file( - name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"], - ) - - http_file( - name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"], + sha256 = "d4d205d08e3e1b09662a9a358d0107e8a8023827ba9b6982a3777bb6c040f936", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1673996821002628"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"], + sha256 = "1b2ffe82b0a25d20188237a724a7cad68d068818a7738f91c69c782314f55965", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1673996823772372"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", + sha256 = "1f367c2d667628b178251aec7fd464327351570edac4549450b11fb82f5f0fd4", + urls = 
["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1673996826132845"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", + sha256 = "35c6ad888c06025dba1f9c8edb70e6c7be7e94e45dc2c0236a2fcfe61991dc44", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1673996828935550"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_internal_js", + sha256 = "68c0134e0b3cb986c3526cd645f74cc5a1f6ab19292276ca7d3558b89801e205", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1673996831356232"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"], + sha256 = "df82bb192ea852dc1bcc8f9f28fbd8c3d6b219dc4fec2b2a92451678d98ee1f0", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1673996834657078"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", + sha256 = "de1a4aabefb2e42ae4fee68b7e762e328623a163257a7ddc72365fc2502bd090", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1673996837104551"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", + sha256 = "828dd1e73fa9478a97a62539117f92b813833ab35d37a986c466df15a8cfdc7b", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1673996840120504"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_internal_js", + sha256 = "c146b68523c256d41132230e811fc224dafb6a0bce6fc318c29dad37dfac06de", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1673996842448396"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"], + sha256 = "8dbccaaf944ef1251cf78190450ab7074abea233e18ebb37d2c2ce0f18d14a0c", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1673996845499070"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", + sha256 = "705f9e3c2c62d12903ea2cadc22d2c328bc890f96fffc47b51f989471196ecea", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1673996847915731"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", + sha256 = "c7ff6a7d8dc22380e2e8457a15a51b6bc1e70c6262fecca25825f54ecc593d1f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1673996850980344"], ) diff --git a/tsconfig.json b/tsconfig.json index c17b1902e..970246dbb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "inlineSourceMap": true, "inlineSources": true, "strict": true, - "types": ["@types/offscreencanvas"], + "types": ["@types/offscreencanvas", "@types/jasmine", "node"], "rootDirs": [ ".", "./bazel-out/host/bin", diff --git a/yarn.lock b/yarn.lock index a5ec6fb13..9c4d91d30 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,34 +3,92 @@ 
"@babel/parser@^7.9.4": - version "7.20.3" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" - integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + version "7.20.5" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.5.tgz#7f3c7335fe417665d929f34ae5dceae4c04015e8" + integrity sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA== + +"@bazel/jasmine@^5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/jasmine/-/jasmine-5.7.2.tgz#438f272e66e939106cbdd58db709cd6aa008131b" + integrity sha512-RJruOB6S9e0efTNIE2JVdaslguUXh5KcmLUCq/xLCt0zENP44ssp9OooDIrZ8H+Sp4mLDNBX7CMMA9WTsbsxTQ== + dependencies: + c8 "~7.5.0" + jasmine-reporters "~2.5.0" "@bazel/rollup@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" - integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.2.tgz#9953b06e3de52794791cee4f89540c263b035fcf" + integrity sha512-yGWLheSKdMnJ/Y3/qg+zCDx/qkD04FBFp+BjRS8xP4yvlz9G4rW3zc45VzHHz3oOywgQaY1vhfKuZMCcjTGEyA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" "@bazel/typescript@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" - integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.2.tgz#a341215dc93ce28794e8430b311756816140bd78" + integrity sha512-tarBJBEIirnq/YaeYu18vXcDxjzlq4xhCXvXUxA0lhHX5oArjEcAEn4tmO0jF+t/7cbkAdMT7daG6vIHSz0QAA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" semver "5.6.0" source-map-support "0.5.9" tsutils "3.21.0" -"@bazel/worker@5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" - integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== +"@bazel/worker@5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.2.tgz#43d800dc1b5a3707340a4eb0102da81c53fc6f63" + integrity sha512-H+auDA0QKF4mtZxKkZ2OKJvD7hGXVsVKtvcf4lbb93ur0ldpb5k810PcDxngmIGBcIX5kmyxniNTIiGFNobWTg== dependencies: google-protobuf "^3.6.1" +"@bcoe/v8-coverage@^0.2.3": + version "0.2.3" + resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" + integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== + +"@istanbuljs/schema@^0.1.2": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98" + integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA== + +"@jridgewell/gen-mapping@^0.3.0": + version "0.3.2" + resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" + integrity sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A== + dependencies: + "@jridgewell/set-array" "^1.0.1" + 
"@jridgewell/sourcemap-codec" "^1.4.10" + "@jridgewell/trace-mapping" "^0.3.9" + +"@jridgewell/resolve-uri@3.1.0": + version "3.1.0" + resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78" + integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w== + +"@jridgewell/set-array@^1.0.1": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@jridgewell/set-array/-/set-array-1.1.2.tgz#7c6cf998d6d20b914c0a55a91ae928ff25965e72" + integrity sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw== + +"@jridgewell/source-map@^0.3.2": + version "0.3.2" + resolved "https://registry.yarnpkg.com/@jridgewell/source-map/-/source-map-0.3.2.tgz#f45351aaed4527a298512ec72f81040c998580fb" + integrity sha512-m7O9o2uR8k2ObDysZYzdfhb08VuEml5oWGiosa1VdaPZ/A6QyPkAJuwN0Q1lhULOf6B7MtQmHENS743hWtCrgw== + dependencies: + "@jridgewell/gen-mapping" "^0.3.0" + "@jridgewell/trace-mapping" "^0.3.9" + +"@jridgewell/sourcemap-codec@1.4.14", "@jridgewell/sourcemap-codec@^1.4.10": + version "1.4.14" + resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz#add4c98d341472a289190b424efbdb096991bb24" + integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw== + +"@jridgewell/trace-mapping@^0.3.9": + version "0.3.17" + resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz#793041277af9073b0951a7fe0f0d8c4c98c36985" + integrity sha512-MCNzAp77qzKca9+W/+I0+sEpaUnZoeasnghNeVc41VZCEKaCH73Vq3BZZ/SzWIgrqE4H4ceI+p+b6C0mHf9T4g== + dependencies: + "@jridgewell/resolve-uri" "3.1.0" + "@jridgewell/sourcemap-codec" "1.4.14" + "@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": version "1.1.2" resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" @@ -85,9 +143,9 @@ integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== "@rollup/plugin-commonjs@^23.0.2": - version "23.0.2" - resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" - integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + version "23.0.3" + resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.3.tgz#442cd8ccca1b7563a503da86fc84a1a7112b54bb" + integrity sha512-31HxrT5emGfTyIfAs1lDQHj6EfYxTXcwtX5pIIhq+B/xZBNIqQ179d/CkYxlpYmFCxT78AeU4M8aL8Iv/IBxFA== dependencies: "@rollup/pluginutils" "^5.0.1" commondir "^1.0.1" @@ -108,6 +166,13 @@ is-module "^1.0.0" resolve "^1.22.1" +"@rollup/plugin-terser@^0.1.0": + version "0.1.0" + resolved "https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" + integrity sha512-N2KK+qUfHX2hBzVzM41UWGLrEmcjVC37spC8R3c9mt3oEDFKh3N2e12/lLp9aVSt86veR0TQiCNQXrm8C6aiUQ== + dependencies: + terser "^5.15.1" + "@rollup/pluginutils@^5.0.1": version "5.0.2" resolved "https://registry.yarnpkg.com/@rollup/pluginutils/-/pluginutils-5.0.2.tgz#012b8f53c71e4f6f9cb317e311df1404f56e7a33" @@ -127,6 +192,21 @@ resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" integrity 
sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== +"@types/is-windows@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/is-windows/-/is-windows-1.0.0.tgz#1011fa129d87091e2f6faf9042d6704cdf2e7be0" + integrity sha512-tJ1rq04tGKuIJoWIH0Gyuwv4RQ3+tIu7wQrC0MV47raQ44kIzXSSFKfrxFUOWVRvesoF7mrTqigXmqoZJsXwTg== + +"@types/istanbul-lib-coverage@^2.0.1": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz#8467d4b3c087805d63580480890791277ce35c44" + integrity sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g== + +"@types/jasmine@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-4.3.1.tgz#2d8ab5601c2fe7d9673dcb157e03f128ab5c5fff" + integrity sha512-Vu8l+UGcshYmV1VWwULgnV/2RDbBaO6i2Ptx7nd//oJPIZGhoI1YLST4VKagD2Pq/Bc2/7zvtvhM7F3p4SN7kQ== + "@types/linkify-it@*": version "3.0.2" resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" @@ -145,10 +225,10 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/node@>=13.7.0": - version "18.11.9" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" - integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== +"@types/node@>=13.7.0", "@types/node@^18.11.11": + version "18.11.11" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.11.tgz#1d455ac0211549a8409d3cdb371cd55cc971e8dc" + integrity sha512-KJ021B1nlQUBLopzZmPBVuGU9un7WJd/W4ya7Ih02B4Uwky5Nja0yGYav2EfYIk0RR2Q9oVhf60S2XR1BCWJ2g== "@types/offscreencanvas@^2019.7.0": version "2019.7.0" @@ -160,17 +240,27 @@ resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" integrity sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== +"@xmldom/xmldom@^0.8.5": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + acorn-jsx@^5.3.2: version "5.3.2" resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ== -acorn@^8.8.0: +acorn@^8.5.0, acorn@^8.8.0: version "8.8.1" resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== -ansi-styles@^4.1.0: +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.3.0" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" integrity 
sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== @@ -217,6 +307,25 @@ builtin-modules@^3.3.0: resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== +c8@~7.5.0: + version "7.5.0" + resolved "https://registry.yarnpkg.com/c8/-/c8-7.5.0.tgz#a69439ab82848f344a74bb25dc5dd4e867764481" + integrity sha512-GSkLsbvDr+FIwjNSJ8OwzWAyuznEYGTAd1pzb/Kr0FMLuV4vqYJTyjboDTwmlUNAG6jAU3PFWzqIdKrOt1D8tw== + dependencies: + "@bcoe/v8-coverage" "^0.2.3" + "@istanbuljs/schema" "^0.1.2" + find-up "^5.0.0" + foreground-child "^2.0.0" + furi "^2.0.0" + istanbul-lib-coverage "^3.0.0" + istanbul-lib-report "^3.0.0" + istanbul-reports "^3.0.2" + rimraf "^3.0.0" + test-exclude "^6.0.0" + v8-to-istanbul "^7.1.0" + yargs "^16.0.0" + yargs-parser "^20.0.0" + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -232,6 +341,15 @@ chalk@^4.0.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + color-convert@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" @@ -244,6 +362,11 @@ color-name@~1.1.4: resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== +commander@^2.20.0: + version "2.20.3" + resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" + integrity sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ== + commondir@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b" @@ -254,6 +377,20 @@ concat-map@0.0.1: resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== +convert-source-map@^1.6.0: + version "1.9.0" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.9.0.tgz#7faae62353fb4213366d0ca98358d22e8368b05f" + integrity sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A== + +cross-spawn@^7.0.0: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + deep-is@~0.1.3: version "0.1.4" resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" @@ -264,11 +401,21 @@ deepmerge@^4.2.2: resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" integrity 
sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" @@ -330,6 +477,22 @@ fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== +find-up@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-5.0.0.tgz#4c92819ecb7083561e4f4a240a86be5198f536fc" + integrity sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng== + dependencies: + locate-path "^6.0.0" + path-exists "^4.0.0" + +foreground-child@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/foreground-child/-/foreground-child-2.0.0.tgz#71b32800c9f15aa8f2f83f4a6bd9bff35d861a53" + integrity sha512-dCIq9FpEcyQyXKCkyzmlPTFNgrCzPudOe+mhvJU5zAtlBnGVy2yKxtfsxK2tQBThwq225jcvBjpw1Gr40uzZCA== + dependencies: + cross-spawn "^7.0.0" + signal-exit "^3.0.2" + fs.realpath@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -345,7 +508,20 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== -glob@^7.1.3: +furi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/furi/-/furi-2.0.0.tgz#13d85826a1af21acc691da6254b3888fc39f0b4a" + integrity sha512-uKuNsaU0WVaK/vmvj23wW1bicOFfyqSsAIH71bRZx8kA4Xj+YCHin7CJKJJjkIsmxYaPFLk9ljmjEyB7xF7WvQ== + dependencies: + "@types/is-windows" "^1.0.0" + is-windows "^1.0.2" + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== @@ -390,6 +566,11 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" +html-escaper@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity 
sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -417,6 +598,11 @@ is-core-module@^2.9.0: dependencies: has "^1.0.3" +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" @@ -429,6 +615,59 @@ is-reference@1.2.1: dependencies: "@types/estree" "*" +is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== + +istanbul-lib-coverage@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz#189e7909d0a39fa5a3dfad5b03f71947770191d3" + integrity sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw== + +istanbul-lib-report@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#7518fe52ea44de372f460a76b5ecda9ffb73d8a6" + integrity sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw== + dependencies: + istanbul-lib-coverage "^3.0.0" + make-dir "^3.0.0" + supports-color "^7.1.0" + +istanbul-reports@^3.0.2: + version "3.1.5" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.1.5.tgz#cc9a6ab25cb25659810e4785ed9d9fb742578bae" + integrity sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w== + dependencies: + html-escaper "^2.0.0" + istanbul-lib-report "^3.0.0" + +jasmine-core@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine-core/-/jasmine-core-4.5.0.tgz#1a6bd0bde3f60996164311c88a0995d67ceda7c3" + integrity sha512-9PMzyvhtocxb3aXJVOPqBDswdgyAeSB81QnLop4npOpbqnheaTEwPc9ZloQeVswugPManznQBjD8kWDTjlnHuw== + +jasmine-reporters@~2.5.0: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jasmine-reporters/-/jasmine-reporters-2.5.2.tgz#b5dfa1d9c40b8020c5225e0e1e2b9953d66a4d69" + integrity sha512-qdewRUuFOSiWhiyWZX8Yx3YNQ9JG51ntBEO4ekLQRpktxFTwUHy24a86zD/Oi2BRTKksEdfWQZcQFqzjqIkPig== + dependencies: + "@xmldom/xmldom" "^0.8.5" + mkdirp "^1.0.4" + +jasmine@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine/-/jasmine-4.5.0.tgz#8d3c0d0a33a61e4d05c9f9747ee5dfaf6f7b5d3d" + integrity sha512-9olGRvNZyADIwYL9XBNBst5BTU/YaePzuddK+YRslc7rI9MdTIE4r3xaBKbv2GEmzYYUfMOdTR8/i6JfLZaxSQ== + dependencies: + glob "^7.1.6" + jasmine-core "^4.5.0" + js2xmlparser@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -479,7 +718,14 @@ linkify-it@^3.0.1: dependencies: uc.micro "^1.0.1" 
-lodash@^4.17.14, lodash@^4.17.15: +locate-path@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-6.0.0.tgz#55321eb309febbc59c4801d931a72452a681d286" + integrity sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw== + dependencies: + p-locate "^5.0.0" + +lodash@^4.17.15, lodash@^4.17.21: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -503,6 +749,13 @@ magic-string@^0.26.4: dependencies: sourcemap-codec "^1.4.8" +make-dir@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.1.0.tgz#415e967046b3a7f1d185277d84aa58203726a13f" + integrity sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw== + dependencies: + semver "^6.0.0" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -520,16 +773,16 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.2" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" - integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== + version "4.2.3" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.3.tgz#bd76a5eb510ff1d8421bc6c3b2f0b93488c15bea" + integrity sha512-slWRdJkbTZ+PjkyJnE30Uid64eHwbwa1Q25INCAYfZlK4o6ylagBy/Le9eWntqJFoFT93ikUKMv47GZ4gTwHkw== mdurl@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== -minimatch@^3.1.1: +minimatch@^3.0.4, minimatch@^3.1.1: version "3.1.2" resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -537,9 +790,9 @@ minimatch@^3.1.1: brace-expansion "^1.1.7" minimatch@^5.0.1: - version "5.1.0" - resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" - integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + version "5.1.1" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.1.tgz#6c9dffcf9927ff2a31e74b5af11adf8b9604b022" + integrity sha512-362NP+zlprccbEt/SkxKfRMHnNY85V74mVnpUpNyr3F35covl09Kec7/sEFLt3RA4oXmewtoaanoIf67SE5Y5g== dependencies: brace-expansion "^2.0.1" @@ -572,11 +825,35 @@ optionator@^0.8.1: type-check "~0.3.2" word-wrap "~1.2.3" +p-limit@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-3.1.0.tgz#e1daccbe78d0d1388ca18c64fea38e3e57e3706b" + integrity sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ== + dependencies: + yocto-queue "^0.1.0" + +p-locate@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-5.0.0.tgz#83c8315c6785005e3bd021839411c9e110e6d834" + integrity sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw== + dependencies: + p-limit "^3.0.2" + +path-exists@^4.0.0: + version "4.0.0" + 
resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + path-is-absolute@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + path-parse@^1.0.7: version "1.0.7" resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" @@ -626,12 +903,17 @@ protobufjs@^7.1.2: "@types/node" ">=13.7.0" long "^5.0.0" +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + requizzle@^0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" - integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + version "0.2.4" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.4.tgz#319eb658b28c370f0c20f968fa8ceab98c13d27c" + integrity sha512-JRrFk1D4OQ4SqovXOgdav+K8EAhSB/LJZqCz8tbX0KObcdeM15Ss59ozWMBWmmINMagCwmqn4ZNryUGpBsl6Jw== dependencies: - lodash "^4.17.14" + lodash "^4.17.21" resolve@^1.22.1: version "1.22.1" @@ -661,6 +943,11 @@ semver@5.6.0: resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== +semver@^6.0.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + semver@^7.1.2: version "7.3.8" resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" @@ -668,6 +955,23 @@ semver@^7.1.2: dependencies: lru-cache "^6.0.0" +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.2: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + source-map-support@0.5.9: version "0.5.9" resolved 
"https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" @@ -676,16 +980,45 @@ source-map-support@0.5.9: buffer-from "^1.0.0" source-map "^0.6.0" +source-map-support@~0.5.20: + version "0.5.21" + resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.21.tgz#04fe7c7f9e1ed2d662233c28cb2b35b9f63f6e4f" + integrity sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w== + dependencies: + buffer-from "^1.0.0" + source-map "^0.6.0" + source-map@^0.6.0, source-map@~0.6.1: version "0.6.1" resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@^0.7.3: + version "0.7.4" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.4.tgz#a9bbe705c9d8846f4e08ff6765acf0f1b0898656" + integrity sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -708,6 +1041,25 @@ taffydb@2.6.2: resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268" integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== +terser@^5.15.1: + version "5.16.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.16.1.tgz#5af3bc3d0f24241c7fb2024199d5c461a1075880" + integrity sha512-xvQfyfA1ayT0qdK47zskQgRZeWLoOQ8JQ6mIgRGVNwZKdQMU+5FkCBjmv4QjcrTzyZquRw2FVtlJSRUmMKQslw== + dependencies: + "@jridgewell/source-map" "^0.3.2" + acorn "^8.5.0" + commander "^2.20.0" + source-map-support "~0.5.20" + +test-exclude@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-6.0.0.tgz#04a8698661d805ea6fa293b6cb9e63ac044ef15e" + integrity sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w== + dependencies: + "@istanbuljs/schema" "^0.1.2" + glob "^7.1.4" + minimatch "^3.0.4" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" @@ -742,9 +1094,9 @@ type-check@~0.3.2: prelude-ls "~1.1.2" typescript@^4.8.4: - version "4.8.4" - resolved 
"https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" - integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + version "4.9.3" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.3.tgz#3aea307c1746b8c384435d8ac36b8a2e580d85db" + integrity sha512-CIfGzTelbKNEnLpLdGFgdyKhG23CKdKgQPOBc+OUNrkJ2vr+KSzsSV5kq5iWhEQbok+quxgGzrAtGWCyU7tHnA== uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" @@ -761,11 +1113,36 @@ underscore@~1.13.2: resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== +v8-to-istanbul@^7.1.0: + version "7.1.2" + resolved "https://registry.yarnpkg.com/v8-to-istanbul/-/v8-to-istanbul-7.1.2.tgz#30898d1a7fa0c84d225a2c1434fb958f290883c1" + integrity sha512-TxNb7YEUwkLXCQYeudi6lgQ/SZrzNO4kMdlqVxaZPUIUjCv6iSSypUQX70kNBSERpQ8fk48+d61FXk+tgqcWow== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.1" + convert-source-map "^1.6.0" + source-map "^0.7.3" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + word-wrap@~1.2.3: version "1.2.3" resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + wrappy@1: version "1.0.2" resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" @@ -776,7 +1153,35 @@ xmlcreate@^2.0.4: resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yargs-parser@^20.0.0, yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.0.0: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file 
"^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + +yocto-queue@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" + integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==