diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md new file mode 100644 index 000000000..4e9ae721d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -0,0 +1,25 @@ +--- +name: "Tasks Issue" +about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. +labels: type:support + +--- +Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- MediaPipe Tasks SDK version: +- Task name (e.g. Object detection, Gesture recognition etc.): +- Programming Language and version (e.g. C++, Python, Java): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need:** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, a GitHub repo, or anything else that we can use to reproduce the problem: + +**Other info / Complete Logs:** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md new file mode 100644 index 000000000..31e8d7f1b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -0,0 +1,25 @@ +--- +name: "Model Maker Issue" +about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions. +labels: type:support + +--- +Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Python version (e.g. 3.8): +- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/): +- Task name (e.g. Image classification, Gesture recognition etc.): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need:** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, a GitHub repo, or anything else that we can use to reproduce the problem: + +**Other info / Complete Logs:** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback.
Large logs and files should be attached: diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md similarity index 81% rename from .github/ISSUE_TEMPLATE/10-solution-issue.md rename to .github/ISSUE_TEMPLATE/13-solution-issue.md index a5332cb36..9297edf6b 100644 --- a/.github/ISSUE_TEMPLATE/10-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- -name: "Solution Issue" -about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +name: "Solution (legacy) Issue" +about: Use this template for assistance with a specific MediaPipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/14-studio-issue.md b/.github/ISSUE_TEMPLATE/14-studio-issue.md new file mode 100644 index 000000000..5942b1eb1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/14-studio-issue.md @@ -0,0 +1,19 @@ +--- +name: "Studio Issue" +about: Use this template for assistance with the MediaPipe Studio application. +labels: type:support + +--- +Please make sure that this is a MediaPipe Studio issue. + +**System information** (Please provide as much relevant information as possible) +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Browser and Version: +- Any microphone or camera hardware: +- URL that shows the problem: + +**Describe the expected behavior:** + +**Other info / Complete Logs:** +Include any JS console logs that would be helpful to diagnose the problem. +Large logs and files should be attached: diff --git a/WORKSPACE b/WORKSPACE index d43394883..bf5e4236b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -320,12 +320,30 @@ http_archive( ], ) -# iOS basic build deps. +# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee +# that the target @zlib//:mini_zlib is available. +http_archive( + name = "zlib", + build_file = "//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], + patches = [ + "@//third_party:zlib.diff", + ], + patch_args = [ + "-p1", + ], +) +# iOS basic build deps. http_archive( name = "build_bazel_rules_apple", - sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2", - url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz", + sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", + url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", patches = [ # Bypass checking iOS unit test runner when building MP iOS applications.
"@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" @@ -339,29 +357,24 @@ load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() load( "@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies", ) - swift_rules_dependencies() -http_archive( - name = "build_bazel_apple_support", - sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35", - urls = [ - "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz" - ], +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", ) +swift_rules_extra_dependencies() load( "@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies", ) - apple_support_dependencies() # More iOS deps. @@ -442,25 +455,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow to guarantee that the target -# @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - # TensorFlow repo should always go after the other external dependencies. # TF on 2022-08-10. _TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b" diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py new file mode 100644 index 000000000..7732b7d56 --- /dev/null +++ b/docs/build_model_maker_api_docs.py @@ -0,0 +1,81 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe Model Maker reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker +$> python build_model_maker_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. 
+ import mediapipe_model_maker # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe-model-maker`.') from e + + +PROJECT_SHORT_NAME = 'mediapipe_model_maker' +PROJECT_FULL_NAME = 'MediaPipe Model Maker' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe-model-maker package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)], + base_dir=os.path.dirname(mediapipe_model_maker.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + callbacks=[], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fe706acd3..02eb04074 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -26,7 +26,6 @@ from absl import app from absl import flags from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. @@ -45,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string( _URL_PREFIX = flags.DEFINE_string( 'code_url_prefix', - 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'https://github.com/google/mediapipe/blob/master/mediapipe', 'The url prefix for links to code.') _SEARCH_HINTS = flags.DEFINE_bool( 'search_hints', True, 'Include metadata search hints in the generated files') -_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python', 'Path prefix in the _toc.yaml') @@ -68,10 +67,7 @@ def gen_api_docs(): code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - # This callback ensures that docs are only generated for objects that - # are explicitly imported in your __init__.py files. There are other - # options but this is a good starting point. 
- callbacks=[public_api.explicit_package_contents_filter], + callbacks=[], ) doc_generator.build(_OUTPUT_DIR.value) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 3b658eb5b..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "concatenate_vector_calculator_proto", srcs = ["concatenate_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -32,7 +31,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "dequantize_byte_array_calculator_proto", srcs = ["dequantize_byte_array_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_cloner_calculator_proto", srcs = ["packet_cloner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_thinner_calculator_proto", srcs = ["packet_thinner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "split_vector_calculator_proto", srcs = ["split_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -82,7 +76,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -92,7 +85,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -102,7 +94,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "gate_calculator_proto", srcs = ["gate_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -112,7 +103,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,7 +114,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = 
"clip_vector_size_calculator_proto", srcs = ["clip_vector_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -134,7 +123,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "flow_limiter_calculator_proto", srcs = ["flow_limiter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -144,7 +132,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "graph_profile_calculator_proto", srcs = ["graph_profile_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -154,7 +141,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "get_vector_item_calculator_proto", srcs = ["get_vector_item_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -164,7 +150,6 @@ mediapipe_proto_library( cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -193,7 +178,6 @@ cc_library( name = "begin_loop_calculator", srcs = ["begin_loop_calculator.cc"], hdrs = ["begin_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -216,7 +200,6 @@ cc_library( name = "end_loop_calculator", srcs = ["end_loop_calculator.cc"], hdrs = ["end_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -258,7 +241,6 @@ cc_test( cc_library( name = "concatenate_vector_calculator_hdr", hdrs = ["concatenate_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -284,7 +266,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -311,7 +292,6 @@ cc_library( cc_library( name = "concatenate_detection_vector_calculator", srcs = ["concatenate_detection_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator", "//mediapipe/framework:calculator_framework", @@ -323,7 +303,6 @@ cc_library( cc_library( name = "concatenate_proto_list_calculator", srcs = ["concatenate_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -372,7 +351,6 @@ cc_library( name = "clip_vector_size_calculator", srcs = ["clip_vector_size_calculator.cc"], hdrs = ["clip_vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,7 +366,6 @@ cc_library( cc_library( name = "clip_detection_vector_size_calculator", srcs = ["clip_detection_vector_size_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator", "//mediapipe/framework:calculator_framework", @@ -415,9 +392,6 @@ cc_test( cc_library( name = 
"counting_source_calculator", srcs = ["counting_source_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -430,9 +404,6 @@ cc_library( cc_library( name = "make_pair_calculator", srcs = ["make_pair_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -461,9 +432,6 @@ cc_test( cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -477,9 +445,6 @@ cc_library( cc_library( name = "matrix_subtract_calculator", srcs = ["matrix_subtract_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -493,9 +458,6 @@ cc_library( cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -508,9 +470,6 @@ cc_library( cc_library( name = "non_zero_calculator", srcs = ["non_zero_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -556,9 +515,6 @@ cc_test( cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -587,7 +543,6 @@ cc_test( cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -611,7 +566,6 @@ cc_test( cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -643,9 +597,6 @@ cc_test( cc_library( name = "pass_through_calculator", srcs = ["pass_through_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -656,9 +607,6 @@ cc_library( cc_library( name = "round_robin_demux_calculator", srcs = ["round_robin_demux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -670,9 +618,6 @@ cc_library( cc_library( name = "immediate_mux_calculator", srcs = ["immediate_mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -684,7 +629,6 @@ cc_library( cc_library( name = "packet_presence_calculator", srcs = ["packet_presence_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -713,7 +657,6 @@ cc_test( cc_library( name = "previous_loopback_calculator", srcs = ["previous_loopback_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -729,7 +672,6 @@ 
cc_library( cc_library( name = "flow_limiter_calculator", srcs = ["flow_limiter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -746,7 +688,6 @@ cc_library( cc_library( name = "string_to_int_calculator", srcs = ["string_to_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -759,7 +700,6 @@ cc_library( cc_library( name = "default_side_packet_calculator", srcs = ["default_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -771,7 +711,6 @@ cc_library( cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -822,9 +761,6 @@ cc_library( name = "packet_resampler_calculator", srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -884,7 +820,6 @@ cc_test( cc_test( name = "matrix_multiply_calculator_test", srcs = ["matrix_multiply_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_multiply_calculator", "//mediapipe/framework:calculator_framework", @@ -900,7 +835,6 @@ cc_test( cc_test( name = "matrix_subtract_calculator_test", srcs = ["matrix_subtract_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_subtract_calculator", "//mediapipe/framework:calculator_framework", @@ -950,7 +884,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -996,7 +929,6 @@ cc_test( cc_library( name = "split_proto_list_calculator", srcs = ["split_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1028,7 +960,6 @@ cc_test( cc_library( name = "dequantize_byte_array_calculator", srcs = ["dequantize_byte_array_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":dequantize_byte_array_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1054,7 +985,6 @@ cc_test( cc_library( name = "quantize_float_vector_calculator", srcs = ["quantize_float_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":quantize_float_vector_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1080,7 +1010,6 @@ cc_test( cc_library( name = "sequence_shift_calculator", srcs = ["sequence_shift_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1105,7 +1034,6 @@ cc_test( cc_library( name = "gate_calculator", srcs = ["gate_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":gate_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1131,7 +1059,6 @@ cc_test( cc_library( name = "matrix_to_vector_calculator", srcs = ["matrix_to_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1167,7 +1094,6 @@ cc_test( cc_library( name = "merge_calculator", srcs = ["merge_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1193,7 +1119,6 @@ cc_test( cc_library( name = "stream_to_side_packet_calculator", srcs = ["stream_to_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -1219,7 +1144,6 @@ cc_test( cc_library( name = "constant_side_packet_calculator", srcs = ["constant_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1249,7 +1173,6 @@ cc_test( cc_library( name = "graph_profile_calculator", srcs = ["graph_profile_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_profile_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1291,7 +1214,6 @@ cc_library( name = "get_vector_item_calculator", srcs = ["get_vector_item_calculator.cc"], hdrs = ["get_vector_item_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1325,7 +1247,6 @@ cc_library( name = "vector_indices_calculator", srcs = ["vector_indices_calculator.cc"], hdrs = ["vector_indices_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1351,7 +1272,6 @@ cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], hdrs = ["vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1365,9 +1285,6 @@ cc_library( cc_library( name = "packet_sequencer_calculator", srcs = ["packet_sequencer_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:contract", @@ -1402,11 +1319,11 @@ cc_library( name = "merge_to_vector_calculator", srcs = ["merge_to_vector_calculator.cc"], hdrs = ["merge_to_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "@com_google_absl//absl/status", ], @@ -1416,7 +1333,6 @@ cc_library( mediapipe_proto_library( name = "bypass_calculator_proto", srcs = ["bypass_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1426,7 +1342,6 @@ mediapipe_proto_library( cc_library( name = "bypass_calculator", srcs = ["bypass_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bypass_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index dc98ccfe7..25d90bfe6 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, 
kIdx, kOut); absl::Status Open(CalculatorContext* cc) final { + cc->SetOffset(mediapipe::TimestampDiff(0)); auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>(); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); return absl::OkStatus(); } @@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node { return absl::OkStatus(); } - RET_CHECK(idx >= 0 && idx < items.size()); - kOut(cc).Send(items[idx]); + RET_CHECK(idx >= 0); + RET_CHECK(options.output_empty_on_oob() || idx < items.size()); + + if (idx < items.size()) { + kOut(cc).Send(items[idx]); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/get_vector_item_calculator.proto b/mediapipe/calculators/core/get_vector_item_calculator.proto index c406283e4..9cfb579e4 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.proto +++ b/mediapipe/calculators/core/get_vector_item_calculator.proto @@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions { // Index of vector item to get. INDEX input stream can be used instead, or to // override. optional int32 item_index = 1; + + // Set to true to output an empty packet when the index is out of bounds. + optional bool output_empty_on_oob = 2; } diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index c148aa9d1..c2974e20a 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() { )"); } -CalculatorRunner MakeRunnerWithOptions(int set_index) { - return CalculatorRunner(absl::StrFormat(R"( +CalculatorRunner MakeRunnerWithOptions(int set_index, + bool output_empty_on_oob = false) { + return CalculatorRunner( + absl::StrFormat(R"( calculator: "TestGetIntVectorItemCalculator" input_stream: "VECTOR:vector_stream" output_stream: "ITEM:item_stream" options { [mediapipe.GetVectorItemCalculatorOptions.ext] { item_index: %d + output_empty_on_oob: %s } } )", - set_index)); + set_index, output_empty_on_oob ? "true" : "false")); } void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs, @@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { @@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { @@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { @@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); +} + +TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) { + const int try_index = 3; + CalculatorRunner runner = MakeRunnerWithOptions(try_index, true); + const std::vector<int> inputs = {1, 2, 3}; + + AddInputVector(runner, inputs, 1); + + MP_ASSERT_OK(runner.Run()); + + const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre()); } TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index cca64bc9a..fd053ed2b 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/calculators/core/merge_to_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" namespace mediapipe { @@ -23,5 +24,13 @@ namespace api2 { typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); +typedef MergeToVectorCalculator<mediapipe::GpuBuffer> + MergeGpuBuffersToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); + +typedef MergeToVectorCalculator<mediapipe::Detection> + MergeDetectionsToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h index bed616695..f63d86ee4 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.h +++ b/mediapipe/calculators/core/merge_to_vector_calculator.h @@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node { return absl::OkStatus(); } + absl::Status Open(::mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Process(CalculatorContext* cc) { const int input_num = kIn(cc).Count(); - std::vector<T> output_vector(input_num); - std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(), - [](const auto& elem) -> T { return elem.Get(); }); + std::vector<T> output_vector; + for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) { + const auto& elem = *it; + if (!elem.IsEmpty()) { + output_vector.push_back(elem.Get()); + } + } kOut(cc).Send(output_vector); return absl::OkStatus(); } } diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 8c9305ffb..1a2b2e5b0 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; namespace mediapipe { namespace { - +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; #if !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc index b3f692889..3c565282b 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto cc = absl::make_unique<CalculatorContext>( calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); - mediapipe::Rect rect = ParseTextProtoOrDie<mediapipe::Rect>( + Rect rect = ParseTextProtoOrDie<Rect>( R"pb( width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5 )pb"); - inputs.Tag(kRectTag).Value() = MakePacket<mediapipe::Rect>(rect); + inputs.Tag(kRectTag).Value() = MakePacket<Rect>(rect); RectSpec expectRect = { .width = 1, .height = 1, diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 54b6c20f1..caade2dc3 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"]) proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = 
["//mediapipe/framework:calculator_proto"], ) @@ -29,14 +29,14 @@ mediapipe_cc_proto_library( name = "callback_packet_calculator_cc_proto", srcs = ["callback_packet_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [":callback_packet_calculator_proto"], ) cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 645189a07..dec68deac 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) exports_files( glob(["testdata/image_to_tensor/*"]), @@ -44,9 +44,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "audio_to_tensor_calculator_proto", srcs = ["audio_to_tensor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -64,9 +61,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -113,9 +107,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_audio_calculator_proto", srcs = ["tensors_to_audio_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -125,9 +116,6 @@ mediapipe_proto_library( cc_library( name = "tensors_to_audio_calculator", srcs = ["tensors_to_audio_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":tensors_to_audio_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -164,9 +152,6 @@ cc_test( mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -184,9 +169,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -216,9 +198,6 @@ cc_test( mediapipe_proto_library( name = "bert_preprocessor_calculator_proto", srcs = ["bert_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -228,9 +207,6 @@ mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -274,9 +250,6 @@ cc_test( mediapipe_proto_library( name = 
"regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -286,9 +259,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -330,9 +300,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -405,7 +372,6 @@ cc_test( mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -432,7 +398,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -457,7 +422,6 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_interface", @@ -475,7 +439,6 @@ cc_library( name = "inference_calculator_gl_advanced", srcs = ["inference_calculator_gl_advanced.cc"], tags = ["nomac"], - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -506,7 +469,6 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], - visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", "//mediapipe/gpu:MPPMetalHelper", @@ -535,7 +497,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -555,7 +516,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -585,7 +545,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -632,7 +591,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -648,7 +606,6 @@ cc_library( cc_library( name = "inference_calculator_gl_if_compute_shader_available", - visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [ @@ -664,7 +621,6 @@ cc_library( # inference_calculator_interface. 
cc_library( name = "inference_calculator", - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_cpu", @@ -678,7 +634,6 @@ cc_library( mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -703,7 +658,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensor_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -722,6 +676,7 @@ cc_library( cc_library( name = "tensor_converter_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -766,7 +721,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_detections_calculator_proto", srcs = ["tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -791,7 +745,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -814,6 +767,7 @@ cc_library( cc_library( name = "tensors_to_detections_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", @@ -829,7 +783,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_landmarks_calculator_proto", srcs = ["tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -846,7 +799,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -861,7 +813,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_tensor_calculator_proto", srcs = ["landmarks_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -879,7 +830,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":landmarks_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -912,7 +862,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -929,7 +878,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -967,7 +915,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -998,7 +945,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_classification_calculator_proto", srcs = ["tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -1036,7 +982,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", @@ -1065,6 +1010,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", @@ -1088,7 +1034,6 @@ cc_library( mediapipe_proto_library( name = "image_to_tensor_calculator_proto", srcs = ["image_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1151,7 +1096,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_utils", "//mediapipe/framework/formats:image", @@ -1171,7 +1115,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -1191,6 +1134,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ @@ -1224,6 +1168,7 @@ cc_library( name = "image_to_tensor_converter_gl_texture", srcs = ["image_to_tensor_converter_gl_texture.cc"], hdrs = ["image_to_tensor_converter_gl_texture.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1248,6 +1193,7 @@ cc_library( name = "image_to_tensor_converter_gl_utils", srcs = ["image_to_tensor_converter_gl_utils.cc"], hdrs = ["image_to_tensor_converter_gl_utils.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1277,6 +1223,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe:apple": [ ":image_to_tensor_converter", @@ -1308,7 +1255,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", "@com_google_absl//absl/status", @@ -1351,7 +1297,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "tensors_to_segmentation_calculator_proto", srcs = ["tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1369,7 +1314,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -1427,7 +1371,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 46552803b..78a0039bc 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ 
b/mediapipe/calculators/tensor/inference_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; option java_package = "com.google.mediapipe.calculator.proto"; option java_outer_classname = "InferenceCalculatorProto"; diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 43eadd53b..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -456,6 +456,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detections_deduplicate_calculator", + srcs = [ + "detections_deduplicate_calculator.cc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc new file mode 100644 index 000000000..2dfa09028 --- /dev/null +++ b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstddef> +#include <functional> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +struct BoundingBoxHash { + size_t operator()(const LocationData::BoundingBox& bbox) const { + return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^ + std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height()); + } +}; + +struct BoundingBoxEq { + bool operator()(const LocationData::BoundingBox& lhs, + const LocationData::BoundingBox& rhs) const { + return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() && + lhs.width() == rhs.width() && lhs.height() == rhs.height(); + } +}; + +} // namespace + +// This calculator deduplicates bounding boxes with exactly the same +// coordinates and folds their labels into a single Detection proto. Note that +// non-maximum suppression removes the overlapping bounding boxes within a +// class, while this deduplication operation merges bounding boxes from +// different classes.
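+// For example, two input detections sharing the exact same bounding box but
+// carrying the labels "cat" and "dog" are folded into one Detection whose
+// score, label, label_id, and display_name fields hold both entries.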
+ +// Example config: +// node { +// calculator: "DetectionsDeduplicateCalculator" +// input_stream: "detections" +// output_stream: "deduplicated_detections" +// } +class DetectionsDeduplicateCalculator : public Node { + public: + static constexpr Input<std::vector<Detection>> kIn{""}; + static constexpr Output<std::vector<Detection>> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Open(mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + + absl::Status Process(mediapipe::CalculatorContext* cc) { + const std::vector<Detection>& raw_detections = kIn(cc).Get(); + absl::flat_hash_map<LocationData::BoundingBox, Detection*, + BoundingBoxHash, BoundingBoxEq> + bbox_to_detections; + std::vector<Detection> deduplicated_detections; + for (const auto& detection : raw_detections) { + if (!detection.has_location_data() || + !detection.location_data().has_bounding_box()) { + return absl::InvalidArgumentError( + "The location data of Detections must be BoundingBox."); + } + if (bbox_to_detections.contains( + detection.location_data().bounding_box())) { + // The bbox location already exists. Merge the detection labels into + // the existing detection proto. + Detection& deduplicated_detection = + *bbox_to_detections[detection.location_data().bounding_box()]; + deduplicated_detection.mutable_score()->MergeFrom(detection.score()); + deduplicated_detection.mutable_label()->MergeFrom(detection.label()); + deduplicated_detection.mutable_label_id()->MergeFrom( + detection.label_id()); + deduplicated_detection.mutable_display_name()->MergeFrom( + detection.display_name()); + } else { + // The bbox location appears for the first time. Add the detection to + // the output detection vector. + deduplicated_detections.push_back(detection); + bbox_to_detections[detection.location_data().bounding_box()] = + &deduplicated_detections.back(); + } + } + kOut(cc).Send(std::move(deduplicated_detections)); + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 73a67d322..3e566836c 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + constexpr float kMinFloat = std::numeric_limits<float>::lowest(); constexpr float kMaxFloat = std::numeric_limits<float>::max(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 6caf792a7..63de60a60 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRectTag[] = "RECT"; constexpr char kDetectionTag[] = "DETECTION"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) && diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index e27edea66..9f276da56 100644 --- 
a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -24,6 +24,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 6673816e7..7a92cfb7e 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using ::mediapipe::NormalizedRect; using mediapipe::OneEuroFilter; +using ::mediapipe::Rect; using mediapipe::RelativeVelocityFilter; void NormalizedLandmarksToLandmarks( diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index dcc6e7391..69b28af87 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -23,6 +23,8 @@ namespace { constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; +using ::mediapipe::NormalizedRect; + } // namespace // Projects rectangle from reference coordinate system (defined by reference diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 400be277d..bbc08255e 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + RenderAnnotation::Rectangle* NewRect( const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { auto* annotation = render_data->add_render_annotations(); diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d94615228..6ff6b3d51 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator to get scale for RenderData primitives. diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 15bb26826..4783cb919 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + // Wraps around an angle in radians to within -M_PI and M_PI. 
inline float NormalizeRadians(float angle) { return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index bcd7352a2..e843d63bf 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index c416fa9b0..48664fead 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -32,6 +32,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr int kDetectionUpdateTimeOutMS = 5000; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES"; diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 10e6422ba..1b733ed82 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import android.support.v7.widget.AppCompatEditText; +import androidx.appcompat.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD index 9b8f8a968..bfa770cec 100644 --- a/mediapipe/examples/ios/common/BUILD +++ b/mediapipe/examples/ios/common/BUILD @@ -29,12 +29,6 @@ objc_library( "Base.lproj/LaunchScreen.storyboard", "Base.lproj/Main.storyboard", ], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], visibility = [ "//mediapipe:__subpackages__", ], @@ -42,6 +36,10 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 50a6f68bd..7d3a75cc6 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -73,13 +73,13 @@ objc_library( "//mediapipe/modules/face_landmark:face_landmark.tflite", ], features = ["-layering_check"], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], deps = [ + "//mediapipe/framework/formats:matrix_data_cc_proto", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", + "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", 
"//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", @@ -87,9 +87,7 @@ objc_library( "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/graphs/face_effect:face_effect_gpu_deps", - "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 02103ce2f..6caf8c09c 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/face_mesh:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 647b7670a..c5b8e7b58 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/hand_tracking:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 056447d63..646d2e5a2 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 86b41ed36..4fbc2280c 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 3cc72b4f1..872944acd 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -21,6 +21,7 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +# The MediaPipe internal package group. No mediapipe users should be added to this group. 
package_group( name = "mediapipe_internal", packages = [ @@ -56,12 +57,12 @@ mediapipe_proto_library( srcs = ["calculator.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:mediapipe_options_proto", - "//mediapipe/framework:packet_factory_proto", - "//mediapipe/framework:packet_generator_proto", - "//mediapipe/framework:status_handler_proto", - "//mediapipe/framework:stream_handler_proto", + ":calculator_options_proto", + ":mediapipe_options_proto", + ":packet_factory_proto", + ":packet_generator_proto", + ":status_handler_proto", + ":stream_handler_proto", "@com_google_protobuf//:any_proto", ], ) @@ -78,8 +79,8 @@ mediapipe_proto_library( srcs = ["calculator_contract_test.proto"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -88,8 +89,8 @@ mediapipe_proto_library( srcs = ["calculator_profile.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -125,14 +126,14 @@ mediapipe_proto_library( name = "status_handler_proto", srcs = ["status_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( name = "stream_handler_proto", srcs = ["stream_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( @@ -141,8 +142,8 @@ mediapipe_proto_library( srcs = ["test_calculators.proto"], visibility = [":mediapipe_internal"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -150,7 +151,7 @@ mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) # It is for pure-native Android builds where the library can't have any dependency on libandroid.so diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 95ab21707..27bc105c8 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) bzl_library( name = "expand_template_bzl", @@ -50,13 +52,11 @@ mediapipe_proto_library( cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cleanup", hdrs = ["cleanup.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -86,7 +86,6 @@ cc_library( # Use this library through "mediapipe/framework/port:gtest_main". 
visibility = [ "//mediapipe/framework/port:__pkg__", - "//third_party/visionai/algorithms/tracking:__pkg__", ], deps = [ "//mediapipe/framework/port:core_proto", @@ -108,7 +107,6 @@ cc_library( name = "file_helpers", srcs = ["file_helpers.cc"], hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":file_path", "//mediapipe/framework/port:status", @@ -134,7 +132,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -151,7 +148,9 @@ cc_library( cc_library( name = "mathutil", hdrs = ["mathutil.h"], - visibility = ["//visibility:public"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [ "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -171,7 +170,6 @@ cc_library( cc_library( name = "no_destructor", hdrs = ["no_destructor.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -190,7 +188,6 @@ cc_library( cc_library( name = "random", hdrs = ["random_base.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/port:integral_types"], ) @@ -211,14 +208,12 @@ cc_library( name = "registration_token", srcs = ["registration_token.cc"], hdrs = ["registration_token.h"], - visibility = ["//visibility:public"], ) cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:logging", @@ -279,7 +274,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], ) cc_library( @@ -310,7 +304,6 @@ cc_library( cc_library( name = "thread_options", hdrs = ["thread_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -356,7 +349,6 @@ cc_library( cc_test( name = "mathutil_unittest", srcs = ["mathutil_unittest.cc"], - visibility = ["//visibility:public"], deps = [ ":mathutil", "//mediapipe/framework/port:benchmark", @@ -368,7 +360,6 @@ cc_test( name = "registration_token_test", srcs = ["registration_token_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:gtest_main", @@ -381,7 +372,6 @@ cc_test( timeout = "long", srcs = ["safe_int_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":intops", "//mediapipe/framework/port:gtest_main", @@ -393,7 +383,6 @@ cc_test( name = "monotonic_clock_test", srcs = ["monotonic_clock_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":clock", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index c31eba350..fdafbff5c 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -361,7 +361,7 @@ void Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - if (!AllocateAhwbMapToSsbo()) { + if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); } glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); @@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { }); } else #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - + { // Transfer data from texture if not transferred from SSBO/MTLBuffer // yet. 
 if (valid_ & kValidOpenGlTexture2d) {
@@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
           }
         });
       }
+    }
 #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
     valid_ |= kValidCpu;
   }
@@ -609,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
 void Tensor::AllocateCpuBuffer() const {
   if (!cpu_buffer_) {
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
-    if (AllocateAHardwareBuffer()) return;
+    if (use_ahwb_ && AllocateAHardwareBuffer()) return;
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB
 #if MEDIAPIPE_METAL_ENABLED
     cpu_buffer_ = AllocateVirtualMemory(bytes());
diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h
index ecd63c8c6..9d3e90b6a 100644
--- a/mediapipe/framework/formats/tensor.h
+++ b/mediapipe/framework/formats/tensor.h
@@ -39,10 +39,9 @@
 #endif  // MEDIAPIPE_NO_JNI
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
+#include <EGL/egl.h>
+#include <EGL/eglext.h>
 #include <android/hardware_buffer.h>
-
-#include "third_party/GL/gl/include/EGL/egl.h"
-#include "third_party/GL/gl/include/EGL/eglext.h"
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 #include "mediapipe/gpu/gl_base.h"
@@ -410,8 +409,8 @@ class Tensor {
   bool AllocateAHardwareBuffer(int size_alignment = 0) const;
   void CreateEglSyncAndFd() const;
   // Use Ahwb for other views: OpenGL / CPU buffer.
-  static inline bool use_ahwb_ = false;
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB
+  static inline bool use_ahwb_ = false;
   // Expects the target SSBO to be already bound.
   bool AllocateAhwbMapToSsbo() const;
   bool InsertAhwbToSsboFence() const;
@@ -419,6 +418,7 @@ class Tensor {
   void ReleaseAhwbStuff();
   void* MapAhwbToCpuRead() const;
   void* MapAhwbToCpuWrite() const;
+  void MoveCpuOrSsboToAhwb() const;
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
   mutable std::shared_ptr<mediapipe::GlContext> gl_context_;
diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc
index b11f6b55b..3c3ec8b17 100644
--- a/mediapipe/framework/formats/tensor_ahwb.cc
+++ b/mediapipe/framework/formats/tensor_ahwb.cc
@@ -4,12 +4,13 @@
 #include "mediapipe/framework/formats/tensor.h"

 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
+#include <EGL/egl.h>
+#include <EGL/eglext.h>
+
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/port.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/gpu/gl_base.h"
-#include "third_party/GL/gl/include/EGL/egl.h"
-#include "third_party/GL/gl/include/EGL/eglext.h"
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB

 namespace mediapipe {
@@ -213,11 +214,16 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
          "supported.";
   CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
       << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
-         "supported on targe system.";
+         "supported on target system.";
+  bool transfer = !ahwb_;
   CHECK(AllocateAHardwareBuffer())
       << "AHardwareBuffer is not supported on the target system.";
   valid_ |= kValidAHardwareBuffer;
-  if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
+  if (transfer) {
+    MoveCpuOrSsboToAhwb();
+  } else {
+    if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
+  }
   return {ahwb_,
           ssbo_written_,
           &fence_fd_,  // The FD is created for SSBO -> AHWB synchronization.
@@ -262,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
 }

 bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
-  if (!use_ahwb_) return false;
   if (__builtin_available(android 26, *)) {
     if (ahwb_ == nullptr) {
       AHardwareBuffer_Desc desc = {};
@@ -302,6 +307,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const {
   return false;
 }

+// Moves the CPU- or SSBO-backed contents into the AHWB-backed memory.
+void Tensor::MoveCpuOrSsboToAhwb() const {
+  void* dest = nullptr;
+  if (__builtin_available(android 26, *)) {
+    auto error = AHardwareBuffer_lock(
+        ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
+    CHECK(error == 0) << "AHardwareBuffer_lock " << error;
+  }
+  if (valid_ & kValidOpenGlBuffer) {
+    gl_context_->Run([this, dest]() {
+      glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
+      const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
+                                         GL_MAP_READ_BIT);
+      std::memcpy(dest, src, bytes());
+      glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
+      glDeleteBuffers(1, &opengl_buffer_);
+    });
+    opengl_buffer_ = GL_INVALID_INDEX;
+    gl_context_ = nullptr;
+  } else if (valid_ & kValidCpu) {
+    std::memcpy(dest, cpu_buffer_, bytes());
+    // Free the CPU memory: from now on the mapped AHWB is used instead.
+    free(cpu_buffer_);
+    cpu_buffer_ = nullptr;
+  } else {
+    LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
+  }
+  if (__builtin_available(android 26, *)) {
+    auto error = AHardwareBuffer_unlock(ahwb_, nullptr);
+    CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
+  }
+}
+
 // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before
 // the GPU task that is going to read from the SSBO. When the writing into AHWB
 // is finished then the GPU reads from the SSBO.
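At the API level, the effect of MoveCpuOrSsboToAhwb() is that a tensor no longer has to be born as an AHWB to be read as one. A usage-level sketch of the path this enables (Android with MEDIAPIPE_TENSOR_USE_AHWB is assumed; the new tensor_ahwb_gpu_test.cc below exercises the same sequence):

    // The tensor starts life as a plain malloc-backed CPU buffer because the
    // preferred storage type is left at kDefault.
    Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
    Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({20})};
    {
      auto view = tensor.GetCpuWriteView();
      view.buffer<float>()[0] = 0.5f;  // write through the CPU view
    }
    // Previously this call would have bailed out of AllocateAHardwareBuffer();
    // now the CPU contents are locked, copied into a freshly allocated
    // AHardwareBuffer, and the old buffer is freed.
    auto ahwb_view = tensor.GetAHardwareBufferReadView();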
diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
new file mode 100644
index 000000000..7ccd9c7f5
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
@@ -0,0 +1,171 @@
+
+#if !defined(MEDIAPIPE_NO_JNI) && \
+    (__ANDROID_API__ >= 26 || \
+     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+#include <android/hardware_buffer.h>
+
+#include <cstdint>
+
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/formats/tensor_data_types.h"
+#include "mediapipe/gpu/gpu_test_base.h"
+#include "mediapipe/gpu/shader_util.h"
+#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
+#include "testing/base/public/gunit.h"
+
+// The test creates an OpenGL ES buffer and fills it with incrementing values
+// 0.0, 0.1, 0.2 etc. with a compute shader on the GPU.
+// Then the test requests the CPU view and compares the values.
+// Both Float32 and Float16 variants are covered.
+
+namespace {
+
+using mediapipe::Float16;
+using mediapipe::Tensor;
+
+MATCHER_P(NearWithPrecision, precision, "") {
+  return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision;
+}
+
+#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+
+// Utility function to fill the GPU buffer.
+void FillGpuBuffer(GLuint name, std::size_t size,
+                   const Tensor::ElementType fmt) {
+  std::string shader_source;
+  if (fmt == Tensor::ElementType::kFloat32) {
+    shader_source = R"( #version 310 es
+      precision highp float;
+      layout(local_size_x = 1, local_size_y = 1) in;
+      layout(std430, binding = 0) buffer Output {float elements[];} output_data;
+      void main() {
+        uint v = gl_GlobalInvocationID.x * 2u;
+        output_data.elements[v] = float(v) / 10.0;
+        output_data.elements[v + 1u] = float(v + 1u) / 10.0;
+      })";
+  } else {
+    shader_source = R"( #version 310 es
+      precision highp float;
+      layout(local_size_x = 1, local_size_y = 1) in;
+      layout(std430, binding = 0) buffer Output {float elements[];} output_data;
+      void main() {
+        uint v = gl_GlobalInvocationID.x;
+        uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0,
+                                     (float(v) * 2.0 + 1.0) / 10.0));
+        output_data.elements[v] = uintBitsToFloat(tmp);
+      })";
+  }
+  GLuint shader;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER));
+  const GLchar* sources[] = {shader_source.c_str()};
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader));
+  GLint is_compiled = 0;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS,
+                                  &is_compiled));
+  if (is_compiled == GL_FALSE) {
+    GLint max_length = 0;
+    MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH,
+                                    &max_length));
+    std::vector<GLchar> error_log(max_length);
+    glGetShaderInfoLog(shader, max_length, &max_length, error_log.data());
+    glDeleteShader(shader);
+    FAIL() << error_log.data();
+    return;
+  }
+  GLuint to_buffer_program;
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program));
+
+  MP_ASSERT_OK(
+      TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
+  MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
+}
+
+class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
+ public:
+};
+
+TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    EXPECT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    EXPECT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  auto ptr = tensor.GetCpuReadView().buffer<Float16>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  // Precision is set to a reasonable value for Float16.
+  EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements),
+              testing::Pointwise(NearWithPrecision(0.001), reference));
+}
+
+TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
+  // Request the CPU view first so that CPU memory is allocated, then request
+  // the Ahwb view to transform the storage into Ahwb.
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  {
+    auto ptr = tensor.GetCpuWriteView().buffer<float>();
+    EXPECT_NE(ptr, nullptr);
+    for (int i = 0; i < num_elements; i++) {
+      ptr[i] = static_cast<float>(i) / 10.0f;
+    }
+  }
+  {
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+  }
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+}  // namespace
+
+#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
+        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h
new file mode 100644
index 000000000..fa0241bde
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_hardware_buffer.h
@@ -0,0 +1,71 @@
+#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
+#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
+
+#if !defined(MEDIAPIPE_NO_JNI) && \
+    (__ANDROID_API__ >= 26 || \
+     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+
+#include <android/hardware_buffer.h>
+
+#include <cstdint>
+
+#include "mediapipe/framework/formats/tensor_buffer.h"
+#include "mediapipe/framework/formats/tensor_internal.h"
+#include "mediapipe/framework/formats/tensor_v2.h"
+
+namespace mediapipe {
+
+// Supports:
+// - float 16 and 32 bits
+// - signed / unsigned integers 8,16,32 bits
+class TensorHardwareBufferView;
+struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
+  using ViewT = TensorHardwareBufferView;
+  TensorBufferDescriptor buffer;
+};
+
+class TensorHardwareBufferView : public Tensor::View {
+ public:
+  TENSOR_UNIQUE_VIEW_TYPE_ID();
+  ~TensorHardwareBufferView() = default;
+
+  const TensorHardwareBufferViewDescriptor& descriptor() const override {
+    return descriptor_;
+  }
+  AHardwareBuffer* handle() const { return ahwb_handle_; }
+
+ protected:
+  TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
+                           Tensor::View::State state,
+                           const TensorHardwareBufferViewDescriptor& desc,
+                           AHardwareBuffer* ahwb_handle)
+      : Tensor::View(kId, access_capability, access, state),
+        descriptor_(desc),
+        ahwb_handle_(ahwb_handle) {}
+
+ private:
+  bool MatchDescriptor(
+      uint64_t view_type_id,
+      const Tensor::ViewDescriptor& base_descriptor) const override {
+    if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
+      return false;
+    auto descriptor = static_cast<const TensorHardwareBufferViewDescriptor&>(
+        base_descriptor);
+    return descriptor.buffer.format == descriptor_.buffer.format &&
+           descriptor.buffer.size_alignment <=
+               descriptor_.buffer.size_alignment &&
+           descriptor_.buffer.size_alignment %
+                   descriptor.buffer.size_alignment ==
+               0;
+  }
+  const TensorHardwareBufferViewDescriptor& descriptor_;
+  AHardwareBuffer* ahwb_handle_ = nullptr;
+};
+
+}  // namespace mediapipe
+
+#endif  // !defined(MEDIAPIPE_NO_JNI) && \
+        //     (__ANDROID_API__ >= 26 || \
+        //     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+
+#endif  // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
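The alignment clauses in MatchDescriptor() above are easy to misread, so here is the rule distilled into a standalone predicate (a sketch; nonzero alignments are assumed): an existing buffer satisfies a request only when the requested size_alignment is no stricter than the stored one and divides it exactly.

    // Mirrors the two size_alignment clauses of MatchDescriptor() above:
    // `stored` is the alignment the buffer was created with, `requested`
    // comes from the incoming view descriptor. Both are assumed nonzero.
    bool AlignmentCompatible(int stored, int requested) {
      return requested <= stored && stored % requested == 0;
    }
    // With stored == 4 (as in the kInt8 padding test further below), requests
    // for 1, 2 and 4 are compatible; 3 fails the divisibility check and 8
    // fails the ordering check.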
diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc
new file mode 100644
index 000000000..9c223ce2c
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc
@@ -0,0 +1,216 @@
+#if !defined(MEDIAPIPE_NO_JNI) && \
+    (__ANDROID_API__ >= 26 || \
+     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+
+#include <cstdint>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "mediapipe/framework/formats/tensor_backend.h"
+#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
+#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
+#include "mediapipe/framework/formats/tensor_v2.h"
+#include "util/task/status_macros.h"
+
+namespace mediapipe {
+namespace {
+
+class TensorCpuViewImpl : public TensorCpuView {
+ public:
+  TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
+                    Tensor::View::State state,
+                    const TensorCpuViewDescriptor& descriptor, void* pointer,
+                    AHardwareBuffer* ahwb_handle)
+      : TensorCpuView(access_capabilities, access, state, descriptor, pointer),
+        ahwb_handle_(ahwb_handle) {}
+  ~TensorCpuViewImpl() {
+    // If handle_ is null then this view is constructed in GetViews with no
+    // access.
+    if (ahwb_handle_) {
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_unlock(ahwb_handle_, nullptr);
+      }
+    }
+  }
+
+ private:
+  AHardwareBuffer* ahwb_handle_;
+};
+
+class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
+ public:
+  TensorHardwareBufferViewImpl(
+      int access_capability, Tensor::View::Access access,
+      Tensor::View::State state,
+      const TensorHardwareBufferViewDescriptor& descriptor,
+      AHardwareBuffer* handle)
+      : TensorHardwareBufferView(access_capability, access, state, descriptor,
+                                 handle) {}
+  ~TensorHardwareBufferViewImpl() = default;
+};
+
+class HardwareBufferCpuStorage : public TensorStorage {
+ public:
+  ~HardwareBufferCpuStorage() {
+    if (!ahwb_handle_) return;
+    if (__builtin_available(android 26, *)) {
+      AHardwareBuffer_release(ahwb_handle_);
+    }
+  }
+
+  static absl::Status CanProvide(
+      int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
+      const Tensor::ViewDescriptor& base_descriptor) {
+    // TODO: use AHardwareBuffer_isSupported for API >= 29.
+    static const bool is_ahwb_supported = [] {
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc = {};
+        // Aligned to the largest possible virtual memory page size.
+        constexpr uint32_t kPageSize = 16384;
+        desc.width = kPageSize;
+        desc.height = 1;
+        desc.layers = 1;
+        desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
+        desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
+                     AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
+        AHardwareBuffer* handle;
+        if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
+        AHardwareBuffer_release(handle);
+        return true;
+      }
+      return false;
+    }();
+    if (!is_ahwb_supported) {
+      return absl::UnavailableError(
+          "AHardwareBuffer is not supported on the platform.");
+    }
+
+    if (view_type_id != TensorCpuView::kId &&
+        view_type_id != TensorHardwareBufferView::kId) {
+      return absl::InvalidArgumentError(
+          "A view type is not supported by this storage.");
+    }
+    return absl::OkStatus();
+  }
+
+  std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
+    std::vector<std::unique_ptr<Tensor::View>> result;
+    auto update_state = latest_version == version_
+                            ? Tensor::View::State::kUpToDate
+                            : Tensor::View::State::kOutdated;
+    if (ahwb_handle_) {
+      result.push_back(
+          std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
+              kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
+              hw_descriptor_, ahwb_handle_)));
+
+      result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
+          kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
+          cpu_descriptor_, nullptr, nullptr)));
+    }
+    return result;
+  }
+
+  absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
+      Tensor::View::Access access, const Tensor::Shape& shape,
+      uint64_t latest_version, uint64_t view_type_id,
+      const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
+    MP_RETURN_IF_ERROR(
+        CanProvide(access_capability, shape, view_type_id, base_descriptor));
+    const auto& buffer_descriptor =
+        view_type_id == TensorHardwareBufferView::kId
+            ? static_cast<const TensorHardwareBufferViewDescriptor&>(
+                  base_descriptor)
+                  .buffer
+            : static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
+                  .buffer;
+    if (!ahwb_handle_) {
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc = {};
+        desc.width = TensorBufferSize(buffer_descriptor, shape);
+        desc.height = 1;
+        desc.layers = 1;
+        desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
+        // TODO: Use access capabilities to set hints.
+        desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
+                     AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
+        auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
+        if (error != 0) {
+          return absl::UnknownError(
+              absl::StrCat("Error allocating hardware buffer: ", error));
+        }
+        // Fill all possible views to provide it as proto views.
+        hw_descriptor_.buffer = buffer_descriptor;
+        cpu_descriptor_.buffer = buffer_descriptor;
+      }
+    }
+    if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
+        buffer_descriptor.size_alignment >
+            hw_descriptor_.buffer.size_alignment ||
+        hw_descriptor_.buffer.size_alignment %
+                buffer_descriptor.size_alignment >
+            0) {
+      return absl::AlreadyExistsError(
+          "A view with different params is already allocated with this "
+          "storage");
+    }
+
+    absl::StatusOr<std::unique_ptr<Tensor::View>> result;
+    if (view_type_id == TensorHardwareBufferView::kId) {
+      result = GetAhwbView(access, shape, base_descriptor);
+    } else {
+      result = GetCpuView(access, shape, base_descriptor);
+    }
+    if (result.ok()) version_ = latest_version;
+    return result;
+  }
+
+ private:
+  absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
+      Tensor::View::Access access, const Tensor::Shape& shape,
+      const Tensor::ViewDescriptor& base_descriptor) {
+    return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
+        kAccessCapability, access, Tensor::View::State::kUpToDate,
+        hw_descriptor_, ahwb_handle_));
+  }
+
+  absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
+      Tensor::View::Access access, const Tensor::Shape& shape,
+      const Tensor::ViewDescriptor& base_descriptor) {
+    void* pointer = nullptr;
+    if (__builtin_available(android 26, *)) {
+      int error =
+          AHardwareBuffer_lock(ahwb_handle_,
+                               access == Tensor::View::Access::kWriteOnly
+                                   ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
+                                   : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
+                               -1, nullptr, &pointer);
+      if (error != 0) {
+        return absl::UnknownError(
+            absl::StrCat("Error locking hardware buffer: ", error));
+      }
+    }
+    return std::unique_ptr<Tensor::View>(
+        new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
+                                  ? Tensor::View::AccessCapability::kWrite
+                                  : Tensor::View::AccessCapability::kRead,
+                              access, Tensor::View::State::kUpToDate,
+                              cpu_descriptor_, pointer, ahwb_handle_));
+  }
+
+  static constexpr int kAccessCapability =
+      Tensor::View::AccessCapability::kRead |
+      Tensor::View::AccessCapability::kWrite;
+  TensorHardwareBufferViewDescriptor hw_descriptor_;
+  AHardwareBuffer* ahwb_handle_ = nullptr;
+
+  TensorCpuViewDescriptor cpu_descriptor_;
+  uint64_t version_ = 0;
+};
+TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);
+
+}  // namespace
+}  // namespace mediapipe
+
+#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
+        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc
new file mode 100644
index 000000000..0afa9899f
--- /dev/null
+++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc
@@ -0,0 +1,76 @@
+
+#if !defined(MEDIAPIPE_NO_JNI) && \
+    (__ANDROID_API__ >= 26 || \
+     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
+#include <cstdint>
+
+#include <memory>
+
+#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
+#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
+#include "mediapipe/framework/formats/tensor_v2.h"
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/gunit.h"
+
+namespace mediapipe {
+
+namespace {
+
+class TensorHardwareBufferTest : public ::testing::Test {
+ public:
+  TensorHardwareBufferTest() {}
+  ~TensorHardwareBufferTest() override {}
+};
+
+TEST_F(TensorHardwareBufferTest, TestFloat32) {
+  Tensor tensor{Tensor::Shape({1})};
+  {
+    MP_ASSERT_OK_AND_ASSIGN(
+        auto view,
+        tensor.GetView(
+            TensorHardwareBufferViewDescriptor{
+                .buffer = {.format =
+                               TensorBufferDescriptor::Format::kFloat32}}));
+    EXPECT_NE(view->handle(),
nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->data(), nullptr); + } +} + +TEST_F(TensorHardwareBufferTest, TestInt8Padding) { + Tensor tensor{Tensor::Shape({1})}; + + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8, + .size_alignment = 4}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + EXPECT_NE(view->data(), nullptr); + } +} + +} // namespace + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index fdc5fe077..718174c45 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream { friend class GraphProfiler; // Accesses OutputStreamShard for profiling. friend class GraphTracer; + // Accesses OutputStreamShard for profiling. + friend class PerfettoTraceScope; // Accesses OutputStreamShard for post processing. friend class OutputStreamManager; }; diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index e499ca3a6..1039dc1c6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-parse_headers"], ) @@ -28,7 +28,6 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "1", }, - visibility = ["//visibility:public"], ) #TODO : remove from OSS. 
@@ -37,13 +36,11 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "0", }, - visibility = ["//visibility:public"], ) cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:aligned_malloc_and_free", "@com_google_absl//absl/base:core_headers", @@ -57,7 +54,6 @@ cc_library( "advanced_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":advanced_proto_lite", ":core_proto", @@ -72,7 +68,6 @@ cc_library( "advanced_proto_lite_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", "//mediapipe/framework:port", @@ -83,7 +78,6 @@ cc_library( cc_library( name = "any_proto", hdrs = ["any_proto.h"], - visibility = ["//visibility:public"], deps = [ ":core_proto", ], @@ -94,7 +88,6 @@ cc_library( hdrs = [ "commandlineflags.h", ], - visibility = ["//visibility:public"], deps = [ "//third_party:glog", "@com_google_absl//absl/flags:flag", @@ -107,7 +100,6 @@ cc_library( "core_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_protobuf//:protobuf", @@ -117,7 +109,6 @@ cc_library( cc_library( name = "file_helpers", hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//mediapipe/framework/deps:file_helpers", @@ -128,7 +119,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "//mediapipe/framework/deps:image_resizer", @@ -140,14 +130,12 @@ cc_library( cc_library( name = "integral_types", hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( name = "benchmark", testonly = 1, hdrs = ["benchmark.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_benchmark//:benchmark", ], @@ -158,7 +146,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:re2", ], @@ -173,7 +160,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -190,7 +176,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -204,7 +189,6 @@ cc_library( hdrs = [ "logging.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//third_party:glog", @@ -217,7 +201,6 @@ cc_library( hdrs = [ "map_util.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:map_util", @@ -227,7 +210,6 @@ cc_library( cc_library( name = "numbers", hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:numbers"], ) @@ -238,13 +220,11 @@ config_setting( define_values = { "MEDIAPIPE_DISABLE_OPENCV": "1", }, - visibility = ["//visibility:public"], ) cc_library( name = "opencv_core", hdrs = ["opencv_core_inc.h"], - visibility = ["//visibility:public"], deps = [ "//third_party:opencv", ], @@ -253,7 +233,6 @@ cc_library( cc_library( name = "opencv_imgproc", hdrs = ["opencv_imgproc_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -263,7 +242,6 @@ cc_library( cc_library( name = "opencv_imgcodecs", hdrs = ["opencv_imgcodecs_inc.h"], - visibility = ["//visibility:public"], 
deps = [ ":opencv_core", "//third_party:opencv", @@ -273,7 +251,6 @@ cc_library( cc_library( name = "opencv_highgui", hdrs = ["opencv_highgui_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -283,7 +260,6 @@ cc_library( cc_library( name = "opencv_video", hdrs = ["opencv_video_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -294,7 +270,6 @@ cc_library( cc_library( name = "opencv_features2d", hdrs = ["opencv_features2d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -304,7 +279,6 @@ cc_library( cc_library( name = "opencv_calib3d", hdrs = ["opencv_calib3d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -314,7 +288,6 @@ cc_library( cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -328,7 +301,6 @@ cc_library( "parse_text_proto.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", ":logging", @@ -339,14 +311,12 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:point"], ) cc_library( name = "port", hdrs = ["port.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/base:core_headers", @@ -356,14 +326,12 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:rectangle"], ) cc_library( name = "ret_check", hdrs = ["ret_check.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:ret_check", @@ -373,7 +341,6 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:singleton"], ) @@ -382,7 +349,6 @@ cc_library( hdrs = [ "source_location.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:source_location", @@ -397,7 +363,6 @@ cc_library( "status_builder.h", "status_macros.h", ], - visibility = ["//visibility:public"], deps = [ ":source_location", "//mediapipe/framework:port", @@ -412,7 +377,6 @@ cc_library( hdrs = [ "statusor.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/status:statusor", @@ -423,7 +387,6 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], - visibility = ["//visibility:private"], deps = [ ":status", "@com_google_googletest//:gtest", @@ -433,7 +396,6 @@ cc_library( cc_library( name = "threadpool", hdrs = ["threadpool.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], @@ -460,7 +422,6 @@ alias( cc_library( name = "topologicalsorter", hdrs = ["topologicalsorter.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:topologicalsorter", @@ -470,6 +431,5 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:vector"], ) diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 
eaabda856..94a4a5646 100644
--- a/mediapipe/framework/port/build_config.bzl
+++ b/mediapipe/framework/port/build_config.bzl
@@ -228,6 +228,8 @@ def mediapipe_ts_library(
         srcs = srcs,
         visibility = visibility,
         deps = deps + [
+            "@npm//@types/jasmine",
+            "@npm//@types/node",
             "@npm//@types/offscreencanvas",
             "@npm//@types/google-protobuf",
         ],
diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD
index b53a1ac39..3b6976fc8 100644
--- a/mediapipe/framework/profiler/BUILD
+++ b/mediapipe/framework/profiler/BUILD
@@ -140,7 +140,7 @@ cc_library(
     name = "circular_buffer",
     hdrs = ["circular_buffer.h"],
     visibility = [
-        "//visibility:public",
+        "//mediapipe:__subpackages__",
     ],
     deps = [
         "//mediapipe/framework/port:integral_types",
@@ -151,7 +151,6 @@ cc_test(
     name = "circular_buffer_test",
     size = "small",
     srcs = ["circular_buffer_test.cc"],
-    visibility = ["//visibility:public"],
    deps = [
         ":circular_buffer",
         "//mediapipe/framework/port:gtest_main",
@@ -164,7 +163,7 @@ cc_library(
     name = "trace_buffer",
     srcs = ["trace_buffer.h"],
     hdrs = ["trace_buffer.h"],
-    visibility = ["//visibility:public"],
+    visibility = ["//mediapipe/framework/profiler:__subpackages__"],
     deps = [
         ":circular_buffer",
         "//mediapipe/framework:calculator_profile_cc_proto",
@@ -292,9 +291,7 @@ cc_library(
             "-ObjC++",
         ],
     }),
-    visibility = [
-        "//mediapipe/framework:mediapipe_internal",
-    ],
+    visibility = ["//visibility:private"],
     deps = [
         "@com_google_absl//absl/flags:flag",
         "//mediapipe/framework/port:logging",
diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h
index 29969af2e..23caed4ec 100644
--- a/mediapipe/framework/profiler/graph_profiler.h
+++ b/mediapipe/framework/profiler/graph_profiler.h
@@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this<GraphProfiler> {
   const ProfilerConfig& profiler_config() { return profiler_config_; }

+  // Helper method to expose the config to other profilers.
+  const ValidatedGraphConfig* GetValidatedGraphConfig() {
+    return validated_graph_;
+  }
+
  private:
   // This can be used to add packet info for the input streams to the graph.
   // It treats the stream defined by |stream_name| as a stream produced by a
diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc
index afef4f383..854c10fd5 100644
--- a/mediapipe/framework/scheduler.cc
+++ b/mediapipe/framework/scheduler.cc
@@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() {
 // Note: state_mutex_ is held when this function is entered or
 // exited.
 void Scheduler::HandleIdle() {
-  if (handling_idle_) {
+  if (++handling_idle_ > 1) {
     // Someone is already inside this method.
     // Note: This can happen in the sections below where we unlock the mutex
     // and make more nodes runnable: the nodes can run and become idle again
@@ -127,7 +127,6 @@ void Scheduler::HandleIdle() {
     VLOG(2) << "HandleIdle: already in progress";
     return;
   }
-  handling_idle_ = true;

   while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) {
     // Remove active sources that are closed.
@@ -165,11 +164,17 @@ void Scheduler::HandleIdle() {
       }
     }

+    // If HandleIdle has been called again, then continue scheduling.
+    if (handling_idle_ > 1) {
+      handling_idle_ = 1;
+      continue;
+    }
+
     // Nothing left to do.
     break;
   }

-  handling_idle_ = false;
+  handling_idle_ = 0;
 }

 // Note: state_mutex_ is held when this function is entered or exited.
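The HandleIdle() rewrite in scheduler.cc above swaps a boolean guard for a counter so that a nested call can do more than bail out: it leaves a mark that the outermost invocation notices and reacts to by looping again instead of returning while work is still pending. The idiom in isolation, as a self-contained sketch (hypothetical class, not the real mediapipe::Scheduler):

    class IdleHandler {
     public:
      // May be called re-entrantly from DoIdleWork().
      void HandleIdle() {
        if (++handling_idle_ > 1) {
          return;  // already inside: leave a mark for the outer invocation
        }
        while (HasIdleWork()) {
          DoIdleWork();  // may re-enter HandleIdle() and bump the counter
          if (handling_idle_ > 1) {
            handling_idle_ = 1;  // collapse nested requests into one more pass
            continue;            // rescan instead of returning early
          }
          break;  // nothing new arrived while we worked: done
        }
        handling_idle_ = 0;
      }

     private:
      bool HasIdleWork() { return pending_ > 0; }
      void DoIdleWork() { --pending_; }
      int pending_ = 3;
      int handling_idle_ = 0;
    };

With a plain bool, the nested call would return silently and its work could be missed; with the counter, the outer loop is guaranteed one more scan for every overlapping request.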
diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index dd1572d99..b59467b9f 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -302,7 +302,7 @@ class Scheduler { // - We need it to be reentrant, which Mutex does not support. // - We want simultaneous calls to return immediately instead of waiting, // and Mutex's TryLock is not guaranteed to work. - bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false; + int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0; // Mutex for the scheduler state and related things. // Note: state_ is declared as atomic so that its getter methods don't need diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 52d04b4b1..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -90,7 +90,7 @@ mediapipe_proto_library( name = "packet_generator_wrapper_calculator_proto", srcs = ["packet_generator_wrapper_calculator.proto"], def_py_proto = False, - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:packet_generator_proto", @@ -120,13 +120,13 @@ cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], hdrs = ["fill_packet_set.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ + ":status_util", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/memory", ], ) @@ -162,7 +162,6 @@ cc_library( cc_test( name = "executor_util_test", srcs = ["executor_util_test.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":executor_util", "//mediapipe/framework/port:gtest_main", @@ -173,7 +172,7 @@ cc_test( cc_library( name = "options_map", hdrs = ["options_map.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe:__subpackages__"], deps = [ ":type_util", "//mediapipe/framework:calculator_cc_proto", @@ -193,7 +192,7 @@ cc_library( name = "options_field_util", srcs = ["options_field_util.cc"], hdrs = ["options_field_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":field_data_cc_proto", ":name_util", @@ -216,7 +215,7 @@ cc_library( name = "options_syntax_util", srcs = ["options_syntax_util.cc"], hdrs = ["options_syntax_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":name_util", ":options_field_util", @@ -235,8 +234,9 @@ cc_library( name = "options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + ":name_util", ":options_field_util", ":options_map", ":options_registry", @@ -254,7 +254,6 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -323,7 +322,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = 
["//mediapipe/framework:__subpackages__"],
     deps = [
         ":packet_generator_wrapper_calculator_cc_proto",
         "//mediapipe/framework:calculator_base",
@@ -347,6 +346,7 @@ cc_library(
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
+        "//mediapipe/framework/port:statusor",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -507,6 +507,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":proto_util_lite",
+        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/deps:proto_descriptor_cc_proto",
         "//mediapipe/framework/port:advanced_proto",
         "//mediapipe/framework/port:integral_types",
diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto
index 27153f3f7..31c233812 100644
--- a/mediapipe/framework/tool/calculator_graph_template.proto
+++ b/mediapipe/framework/tool/calculator_graph_template.proto
@@ -27,6 +27,9 @@ message TemplateExpression {
   // The FieldDescriptor::Type of the modified field.
   optional mediapipe.FieldDescriptorProto.Type field_type = 5;

+  // The FieldDescriptor::Type of each map key in the path.
+  repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
+
   // Alternative value for the modified field, in protobuf binary format.
   optional string field_value = 7;
 }
diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc
index 4628815ea..a810ce129 100644
--- a/mediapipe/framework/tool/proto_util_lite.cc
+++ b/mediapipe/framework/tool/proto_util_lite.cc
@@ -22,6 +22,7 @@
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/port/statusor.h"
 #include "mediapipe/framework/tool/field_data.pb.h"
 #include "mediapipe/framework/type_map.h"

@@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,

 // Extracts the data value(s) for one field from a serialized message.
 // The message with these field values removed is written to |out|.
-absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
-                            CodedInputStream* in, CodedOutputStream* out,
+absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
+                            CodedOutputStream* out,
                             std::vector<std::string>* field_values) {
   uint32 tag;
   while ((tag = in->ReadTag()) != 0) {
     int field_number = WireFormatLite::GetTagFieldNumber(tag);
+    WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
     if (field_number == field_id) {
       if (!IsLengthDelimited(wire_type) &&
           IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
@@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
   CodedInputStream in(&ais);
   StringOutputStream sos(&message_);
   CodedOutputStream out(&sos);
-  WireFormatLite::WireType wire_type =
-      WireFormatLite::WireTypeForFieldType(field_type_);
-  return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
+  return GetFieldValues(field_id_, &in, &out, &field_values_);
 }

 void FieldAccess::GetMessage(std::string* result) {
@@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
   return &field_values_;
 }

+namespace {
+using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
+
+// Returns the FieldAccess and index for a field-id or a map-id.
+// Returns access to the field-id if the field index is found,
+// to the map-id if the map entry is found, and to the field-id otherwise.
+absl::StatusOr<std::pair<FieldAccess, int>> AccessField( + const ProtoPathEntry& entry, FieldType field_type, + const FieldValue& message) { + FieldAccess result(entry.field_id, field_type); + if (entry.field_id >= 0) { + MP_RETURN_IF_ERROR(result.SetMessage(message)); + if (entry.index < result.mutable_field_values()->size()) { + return std::pair(result, entry.index); + } + } + if (entry.map_id >= 0) { + FieldAccess access(entry.map_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(message)); + auto& field_values = *access.mutable_field_values(); + for (int index = 0; index < field_values.size(); ++index) { + FieldAccess key(entry.key_id, entry.key_type); + MP_RETURN_IF_ERROR(key.SetMessage(field_values[index])); + if (key.mutable_field_values()->at(0) == entry.key_value) { + return std::pair(std::move(access), index); + } + } + } + if (entry.field_id >= 0) { + return std::pair(result, entry.index); + } + return absl::InvalidArgumentError(absl::StrCat( + "ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ", + entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type)); +} + +} // namespace + // Replaces a range of field values for one field nested within a protobuf. absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector<FieldValue>& field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(*message)); - std::vector<std::string>& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector<std::string>& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, @@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector<FieldValue>* field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(message)); - std::vector<std::string>& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? 
WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector<std::string>& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldRange(v[index], proto_path, length, field_type, field_values)); } else { + if (length == -1) { + length = v.size() - index; + } RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); field_values->insert(field_values->begin(), v.begin() + index, @@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, ProtoPath proto_path, FieldType field_type, int* field_count) { - int field_id, index; - std::tie(field_id, index) = proto_path.back(); - proto_path.pop_back(); - std::vector<std::string> parent; - if (proto_path.empty()) { - parent.push_back(std::string(message)); + ProtoPathEntry entry = proto_path.front(); + proto_path.erase(proto_path.begin()); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector<std::string>& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldCount(v[index], proto_path, field_type, field_count)); } else { - MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( - message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + *field_count = v.size(); } - FieldAccess access(field_id, field_type); - MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); - *field_count = access.mutable_field_values()->size(); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index 7d3a263f3..15e321eeb 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -34,15 +34,36 @@ class ProtoUtilLite { // Defines field types and tag formats. using WireFormatLite = proto_ns::internal::WireFormatLite; - // Defines a sequence of nested field-number field-index pairs. - using ProtoPath = std::vector<std::pair<int, int>>; - // The serialized value for a protobuf field. using FieldValue = std::string; // The serialized data type for a protobuf field. using FieldType = WireFormatLite::FieldType; + // A field-id and index, or a map-id and key, or both. + struct ProtoPathEntry { + ProtoPathEntry(int id, int index) : field_id(id), index(index) {} + ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value) + : map_id(id), + key_id(key_id), + key_type(key_type), + key_value(std::move(key_value)) {} + bool operator==(const ProtoPathEntry& o) const { + return field_id == o.field_id && index == o.index && map_id == o.map_id && + key_id == o.key_id && key_type == o.key_type && + key_value == o.key_value; + } + int field_id = -1; + int index = -1; + int map_id = -1; + int key_id = -1; + FieldType key_type = FieldType::MAX_FIELD_TYPE; + FieldValue key_value; + }; + + // Defines a sequence of nested field selectors: field-number/field-index + // pairs or map-id/key entries. + using ProtoPath = std::vector<ProtoPathEntry>; + class FieldAccess { public: // Provides access to a certain protobuf field.
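For orientation, a sketch (not part of the patch) of how the two ProtoPathEntry forms above combine into a ProtoPath. The field numbers and the key "INPUT_FRAMES" are hypothetical, and key_value is assumed to hold the key in its serialized wire form, which for TYPE_STRING is the raw bytes (SetMapKeyTypes in template_expander.cc arranges this):

using ProtoUtilLite = mediapipe::tool::ProtoUtilLite;

// Builds the path "/1[1]/2[@1=INPUT_FRAMES]": index 1 of repeated field 1,
// then the entry of map field 2 whose key field (number 1, a string) equals
// "INPUT_FRAMES".
ProtoUtilLite::ProtoPath MakeExamplePath() {
  ProtoUtilLite::ProtoPath path;
  path.push_back({1, 1});  // (field_id, index) form
  path.push_back({2, 1, ProtoUtilLite::FieldType::TYPE_STRING,
                  "INPUT_FRAMES"});  // (map_id, key_id, key_type, key_value)
  return path;
}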
@@ -57,9 +78,11 @@ class ProtoUtilLite { // Returns the serialized values of the protobuf field. std::vector<std::string>* mutable_field_values(); + uint32 field_id() const { return field_id_; } + private: - const uint32 field_id_; - const FieldType field_type_; + uint32 field_id_; + FieldType field_type_; std::string message_; std::vector<std::string> field_values_; }; diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 034e1a026..a91ea5adc 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -22,6 +22,7 @@ #include <vector> #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite; using FieldValue = ProtoUtilLite::FieldValue; using FieldType = ProtoUtilLite::FieldType; using ProtoPath = ProtoUtilLite::ProtoPath; +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; namespace { @@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) { return result; } -// Returns the (tag, index) pairs in a field path. -// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". -absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - absl::Status status; - std::vector<std::string> ids = absl::StrSplit(path, '/'); - for (const std::string& id : ids) { - if (id.length() > 0) { - std::pair<std::string, std::string> id_pair = - absl::StrSplit(id, absl::ByAnyChar("[]")); - int tag = 0; - int index = 0; - bool ok = absl::SimpleAtoi(id_pair.first, &tag) && - absl::SimpleAtoi(id_pair.second, &index); - if (!ok) { - status.Update(absl::InvalidArgumentError(path)); - } - result->push_back(std::make_pair(tag, index)); +// Parses one ProtoPathEntry. +// The parsed entry is appended to `result` and removed from `path`. +// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes +// to serialize the key text to protobuf wire format. +absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) { + bool ok = true; + int sb = path.find('['); + int eb = path.find(']'); + int field_id = -1; + ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id); + auto selector = path.substr(sb + 1, eb - 1 - sb); + if (absl::StartsWith(selector, "@")) { + int eq = selector.find('='); + int key_id = -1; + ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id); + auto key_text = selector.substr(eq + 1); + FieldType key_type = FieldType::TYPE_STRING; + result->push_back({field_id, key_id, key_type, std::string(key_text)}); + } else { + int index = 0; + ok &= absl::SimpleAtoi(selector, &index); + result->push_back({field_id, index}); + } + int end = path.find('/', eb); + if (end == std::string::npos) { + path = ""; + } else { + path = path.substr(end + 1); + } + return ok ? absl::OkStatus() : absl::InvalidArgumentError( absl::StrCat("Failed to parse ProtoPath entry: ", path)); +} + +// Specifies the FieldTypes for protobuf map keys in a ProtoPath. +// Each ProtoPathEntry::key_value is converted from text to the protobuf +// wire format for its key type. 
+absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types, + ProtoPath* result) { + int i = 0; + for (ProtoPathEntry& entry : *result) { + if (entry.map_id >= 0) { + FieldType key_type = key_types[i++]; + std::vector<FieldValue> key_value; + MP_RETURN_IF_ERROR( + ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value)); + entry.key_type = key_type; + entry.key_value = key_value.front(); } } - return status; + return absl::OkStatus(); +} + +// Returns the (tag, index) pairs in a field path. +// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]", +// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]". +absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + result->clear(); + absl::string_view rest = path; + if (absl::StartsWith(rest, "/")) { + rest = rest.substr(1); + } + while (!rest.empty()) { + MP_RETURN_IF_ERROR(ParseEntry(rest, result)); + } + return absl::OkStatus(); +} + +// Parse the TemplateExpression.path field into a ProtoPath struct. +absl::Status ParseProtoPath(const TemplateExpression& rule, + std::string base_path, ProtoPath* result) { + ProtoPath base_entries; + MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries)); + MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result)); + std::vector<FieldType> key_types; + for (int type : rule.key_type()) { + key_types.push_back(static_cast<FieldType>(type)); + } + MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result)); + result->erase(result->begin(), result->begin() + base_entries.size()); + return absl::OkStatus(); } // Returns true if one proto path is prefixed by another. bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) { return absl::StartsWith(path, prefix); } -// Returns the part of one proto path after a prefix proto path. -std::string ProtoPathRelative(const std::string& field_path, - const std::string& base_path) { - CHECK(ProtoPathStartsWith(field_path, base_path)); - return field_path.substr(base_path.length()); -} - // Returns the target ProtoUtilLite::FieldType of a rule. FieldType GetFieldType(const TemplateExpression& rule) { return static_cast<FieldType>(rule.field_type()); 
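A quick illustration of the parsing behavior above; a sketch under the same aliases (ProtoPath, ProtoPathEntry) used in template_expander.cc, with hypothetical field numbers:

// Sketch: what ProtoPathSplit is expected to produce for the two syntaxes.
absl::Status DemoSplit() {
  ProtoPath path;
  MP_RETURN_IF_ERROR(ProtoPathSplit("/1[1]/2[@1=INPUT_FRAMES]", &path));
  // path[0] == ProtoPathEntry(1, 1)
  // path[1] == ProtoPathEntry(2, /*key_id=*/1, TYPE_STRING, "INPUT_FRAMES"),
  // where key_value still holds the key text until SetMapKeyTypes replaces
  // it with the serialized form for the key's actual type.
  return absl::OkStatus();
}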
@@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) { // Returns the count of field values at a ProtoPath. int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { - int field_id, index; - std::tie(field_id, index) = field_path.back(); - field_path.pop_back(); - std::vector<FieldValue> parent; - if (field_path.empty()) { - parent.push_back(base); - } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange( - base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); - } - ProtoUtilLite::FieldAccess access(field_id, field_type); - MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0])); - return access.mutable_field_values()->size(); + int result = 0; + CHECK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + return result; } } // namespace @@ -229,9 +276,7 @@ class TemplateExpanderImpl { return absl::OkStatus(); } ProtoPath field_path; - absl::Status status = - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); - if (!status.ok()) return status; + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); return ProtoUtilLite::GetFieldRange(output, field_path, 1, GetFieldType(rule), base); } @@ -242,12 +287,13 @@ class TemplateExpanderImpl { const std::vector<FieldValue>& field_values, FieldValue* output) { if (!rule.has_path()) { - *output = field_values[0]; + if (!field_values.empty()) { + *output = field_values[0]; + } return absl::OkStatus(); } ProtoPath field_path; - RET_CHECK_OK( - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path)); + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); int field_count = 1; if (rule.has_field_value()) { // For a non-repeated field, only one value can be specified. @@ -257,7 +303,7 @@ "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. - field_path[field_path.size() - 1].second = 0; + field_path[field_path.size() - 1].index = 0; field_count = 0; } return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count, diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 1d81e7a78..cf23f3443 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" @@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message; using mediapipe::proto_ns::OneofDescriptor; using mediapipe::proto_ns::Reflection; using mediapipe::proto_ns::TextFormat; +using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath; +using FieldType = mediapipe::tool::ProtoUtilLite::FieldType; +using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue; namespace mediapipe { @@ -1357,32 +1361,138 @@ absl::Status ProtoPathSplit(const std::string& path, if (!ok) { status.Update(absl::InvalidArgumentError(path)); } - result->push_back(std::make_pair(tag, index)); + result->push_back({tag, index}); } } return status; }
+// Returns a message serialized deterministically. +bool DeterministicallySerialize(const Message& proto, std::string* result) { + proto_ns::io::StringOutputStream stream(result); + proto_ns::io::CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + return proto.SerializeToCodedStream(&output); +} + // Serialize one field of a message. void SerializeField(const Message* message, const FieldDescriptor* field, std::vector<ProtoUtilLite::FieldValue>* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast<ProtoUtilLite::FieldType>(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); *result = *access.mutable_field_values(); } +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Returns all FieldDescriptors including extensions. +std::vector<const FieldDescriptor*> GetFields(const Message* src) { + std::vector<const FieldDescriptor*> result; + src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(), + &result); + for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) { + result.push_back(src->GetDescriptor()->field(i)); + } + return result; +} + +// Orders map entries in dst to match src. +void OrderMapEntries(const Message* src, Message* dst, + std::set<const Message*>* seen = nullptr) { + std::unique_ptr<std::set<const Message*>> seen_owner; + if (!seen) { + seen_owner = std::make_unique<std::set<const Message*>>(); + seen = seen_owner.get(); + } + if (seen->count(src) > 0) { + return; + } else { + seen->insert(src); + } + for (auto field : GetFields(src)) { + if (field->is_map()) { + dst->GetReflection()->ClearField(dst, field); + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + const Message& entry = + src->GetReflection()->GetRepeatedMessage(*src, field, j); + dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry); + } + } + if (field->type() == FieldDescriptor::TYPE_MESSAGE) { + if (field->is_repeated()) { + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + OrderMapEntries( + &src->GetReflection()->GetRepeatedMessage(*src, field, j), + dst->GetReflection()->MutableRepeatedMessage(dst, field, j), + seen); + } + } else { + OrderMapEntries(&src->GetReflection()->GetMessage(*src, field), + dst->GetReflection()->MutableMessage(dst, field), seen); + } + } + } +} +
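DeterministicallySerialize and OrderMapEntries above attack the same hazard from two sides: protobuf map fields have no guaranteed wire order, so the parser serializes maps stably and then re-applies the parsed entry order. A standalone sketch of the serialization half, using the public protobuf API directly instead of MediaPipe's proto_ns aliases:

#include <string>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"

// Serializes `proto` with map fields emitted in a stable (sorted) order.
std::string SerializeStably(const google::protobuf::Message& proto) {
  std::string result;
  {
    google::protobuf::io::StringOutputStream stream(&result);
    google::protobuf::io::CodedOutputStream output(&stream);
    output.SetSerializationDeterministic(true);
    proto.SerializeToCodedStream(&output);
  }  // CodedOutputStream flushes into `result` when it goes out of scope.
  return result;
}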
+// Copies a Message, keeping map entries in order. +std::unique_ptr<Message> CloneMessage(const Message* message) { + std::unique_ptr<Message> result(message->New()); + result->CopyFrom(*message); + OrderMapEntries(message, result.get()); + return result; +} + +using MessageMap = std::map<std::string, std::unique_ptr<Message>>; + // For a non-repeated field, move the most recently parsed field value // into the most recently parsed template expression. -void StowFieldValue(Message* message, TemplateExpression* expression) { +void StowFieldValue(Message* message, TemplateExpression* expression, + MessageMap* stowed_messages) { const Reflection* reflection = message->GetReflection(); const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); - int field_number = path[path.size() - 1].first; + int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); + + // Save each stowed message unserialized, preserving map entry order. + if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) { + (*stowed_messages)[ProtoPathJoin(path)] = + CloneMessage(GetFieldMessage(*message, field, 0)); + } + if (!field->is_repeated()) { std::vector<ProtoUtilLite::FieldValue> field_values; SerializeField(message, field, &field_values); @@ -1402,6 +1512,112 @@ static void StripQuotes(std::string* str) { } } +// Returns the field or extension for field number. +const FieldDescriptor* FindFieldByNumber(const Message* message, + int field_num) { + const FieldDescriptor* result = + message->GetDescriptor()->FindFieldByNumber(field_num); + if (result == nullptr) { + result = message->GetReflection()->FindKnownExtensionByNumber(field_num); + } + return result; +} + +// Returns the protobuf map key types from a ProtoPath. +std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) { + std::vector<FieldType> result; + for (auto& entry : path) { + if (entry.map_id >= 0) { + result.push_back(entry.key_type); + } + } + return result; +} + +// Returns the text value for a string or numeric protobuf map key. +std::string GetMapKey(const Message& map_entry) { + auto key_field = map_entry.GetDescriptor()->FindFieldByName("key"); + auto reflection = map_entry.GetReflection(); + if (key_field->type() == FieldDescriptor::TYPE_STRING) { + return reflection->GetString(map_entry, key_field); + } else if (key_field->type() == FieldDescriptor::TYPE_INT32) { + return absl::StrCat(reflection->GetInt32(map_entry, key_field)); + } else if (key_field->type() == FieldDescriptor::TYPE_INT64) { + return absl::StrCat(reflection->GetInt64(map_entry, key_field)); + } + return ""; +} + +// Returns a Message stored in CalculatorGraphTemplate::field_value. +Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) { + auto it = stowed_messages->find(ProtoPathJoin(proto_path)); + return (it != stowed_messages->end()) ? it->second.get() : nullptr; +} + +const Message* GetNestedMessage(const Message& message, + const FieldDescriptor* field, + ProtoPath proto_path, + MessageMap* stowed_messages) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + const Message* result = FindStowedMessage(stowed_messages, proto_path); + if (!result) { + result = GetFieldMessage(message, field, proto_path.back().index); + } + return result; +} + +// Adjusts map-entries from indexes to keys. +// Protobuf map-entry order is intentionally not preserved. +absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) { + // Copy the rules from the source CalculatorGraphTemplate. 
+ mediapipe::CalculatorGraphTemplate rules; + rules.ParsePartialFromString(source->SerializePartialAsString()); + // Only the "source" Message knows all extension types. + Message* config_0 = source->GetReflection()->MutableMessage( + source, source->GetDescriptor()->FindFieldByName("config"), nullptr); + for (int i = 0; i < rules.rule().size(); ++i) { + TemplateExpression* rule = rules.mutable_rule()->Mutable(i); + const Message* message = config_0; + ProtoPath path; + MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); + for (int j = 0; j < path.size(); ++j) { + int field_id = path[j].field_id; + const FieldDescriptor* field = FindFieldByNumber(message, field_id); + ProtoPath prefix = {path.begin(), path.begin() + j + 1}; + message = GetNestedMessage(*message, field, prefix, stowed_messages); + if (!message) { + break; + } + if (field->is_map()) { + const Message* map_entry = message; + int key_id = + map_entry->GetDescriptor()->FindFieldByName("key")->number(); + FieldType key_type = static_cast<FieldType>( + map_entry->GetDescriptor()->FindFieldByName("key")->type()); + std::string key_value = GetMapKey(*map_entry); + path[j] = {field_id, key_id, key_type, key_value}; + } + } + if (!rule->path().empty()) { + *rule->mutable_path() = ProtoPathJoin(path); + for (FieldType key_type : ProtoPathKeyTypes(path)) { + *rule->mutable_key_type()->Add() = key_type; + } + } + } + // Copy the rules back into the source CalculatorGraphTemplate. + auto source_rules = + source->GetReflection()->GetMutableRepeatedFieldRef<Message>( + source, source->GetDescriptor()->FindFieldByName("rule")); + source_rules.Clear(); + for (auto& rule : rules.rule()) { + source_rules.Add(rule); + } + return absl::OkStatus(); +} + } // namespace class TemplateParser::Parser::MediaPipeParserImpl
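The observable effect of KeyProtoMapEntries is a rewrite of each stored rule path from positional map-entry form to keyed form, with the key's descriptor type recorded alongside. Illustrative only; the field numbers, the key "INPUT_FRAMES", and the generated accessors are assumptions based on the proto definitions in this patch:

#include "mediapipe/framework/tool/calculator_graph_template.pb.h"

// Sketch of the before/after state of one rule (not a real transformation).
void IllustrateRewrite(mediapipe::TemplateExpression* rule) {
  // Positional form, as recorded by the parser (map entry at index 0):
  rule->set_path("/1[1]/2[0]");
  // Keyed form, as produced by KeyProtoMapEntries, plus the key's type:
  rule->set_path("/1[1]/2[@1=INPUT_FRAMES]");
  rule->add_key_type(mediapipe::FieldDescriptorProto::TYPE_STRING);
}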
@@ -1416,6 +1632,8 @@ // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); + // Replace map-entry indexes with map keys. + success &= KeyProtoMapEntries(output, &stowed_messages_).ok(); return success; } @@ -1441,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl DO(ConsumeFieldTemplate(message)); } else { DO(ConsumeField(message)); - StowFieldValue(message, expression); + StowFieldValue(message, expression, &stowed_messages_); } DO(ConsumeEndTemplate()); return true; @@ -1652,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl } mediapipe::CalculatorGraphTemplate template_rules_; + std::map<std::string, std::unique_ptr<Message>> stowed_messages_; }; #undef DO diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index 906688520..8300181b5 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -17,10 +17,13 @@ load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_simple_subgraph", ) +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//mediapipe:__subpackages__"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) filegroup( name = "test_graph", @@ -40,7 +43,6 @@ mediapipe_simple_subgraph( testonly = 1, graph = "dub_quad_test_subgraph.pbtxt", register_as = "DubQuadTestSubgraph", - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:test_calculators", ], ) @@ -51,9 +53,18 @@ mediapipe_simple_subgraph( testonly = 1, graph = "nested_test_subgraph.pbtxt", register_as = "NestedTestSubgraph", - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":dub_quad_test_subgraph", "//mediapipe/framework:test_calculators", ], ) + +mediapipe_proto_library( + name = "frozen_generator_proto", + srcs = ["frozen_generator.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [ + "//mediapipe/framework:packet_generator_proto", + ], +) diff --git a/mediapipe/framework/tool/testdata/frozen_generator.proto b/mediapipe/framework/tool/testdata/frozen_generator.proto new file mode 100644 index 000000000..5f133f461 --- /dev/null +++ b/mediapipe/framework/tool/testdata/frozen_generator.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + +message FrozenGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional FrozenGeneratorOptions ext = 225748738; + } + + // Path to file containing serialized proto of type tensorflow::GraphDef. + optional string graph_proto_path = 1; + + // This map defines which streams are fed to which tensors in the model. + map<string, string> tag_to_tensor_names = 2; + + // Graph nodes to run to initialize the model. + repeated string initialization_op_names = 4; +}
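FrozenGeneratorOptions gives the template tests a map field to exercise. A sketch of populating it from C++, assuming the standard protoc-generated accessors; the path and tensor names are placeholders:

#include "mediapipe/framework/tool/testdata/frozen_generator.pb.h"

mediapipe::FrozenGeneratorOptions MakeOptions() {
  mediapipe::FrozenGeneratorOptions options;
  options.set_graph_proto_path("/tmp/frozen_graph.pb");  // placeholder path
  // Map entries like this one are what a keyed ProtoPath such as
  // ".../2[@1=INPUT_FRAMES]" addresses by key rather than by index.
  (*options.mutable_tag_to_tensor_names())["INPUT_FRAMES"] = "input_frames:0";
  return options;
}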
diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 01e3da83e..15eac3209 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize( input_side_packets_.clear(); output_side_packets_.clear(); stream_to_producer_.clear(); + output_streams_to_consumer_nodes_.clear(); input_streams_.clear(); output_streams_.clear(); owned_packet_types_.clear(); @@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode( << " does not have a corresponding output stream."; } } + // Add this node as a consumer of this edge's output stream. + if (edge_info.upstream > -1) { + auto parent_node = output_streams_[edge_info.upstream].parent_node; + if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) { + int this_idx = node_type_info->Node().index; + output_streams_to_consumer_nodes_[edge_info.upstream].push_back( + this_idx); + } + } edge_info.parent_node = node_type_info->Node(); edge_info.name = name; diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 11f9553cd..95ecccbb4 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -282,6 +282,14 @@ class ValidatedGraphConfig { return output_streams_[iter->second].parent_node.index; } + std::vector<int> OutputStreamToConsumers(int idx) const { + auto iter = output_streams_to_consumer_nodes_.find(idx); + if (iter == output_streams_to_consumer_nodes_.end()) { + return {}; + } + return iter->second; + } + // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. absl::StatusOr<std::string> RegisteredSidePacketTypeName( @@ -418,6 +426,10 @@ class ValidatedGraphConfig { // Mapping from stream name to the output_streams_ index which produces it. std::map<std::string, int> stream_to_producer_; + + // Mapping from output streams to consumer node ids. Used for profiling. + std::map<int, std::vector<int>> output_streams_to_consumer_nodes_; + // Mapping from side packet name to the output_side_packets_ index // which produces it. std::map<std::string, int> side_packet_to_producer_;
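A usage sketch for the new accessor; this patch only adds the bookkeeping, so the caller below is hypothetical. Given the index of an output stream, a profiler could attribute its packets to the calculator nodes that consume it:

#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/validated_graph_config.h"

// Logs the consumer node ids of one output stream. `stream_index` is assumed
// to be a valid index into the graph's output-stream table.
void LogConsumers(const mediapipe::ValidatedGraphConfig& graph,
                  int stream_index) {
  for (int node_id : graph.OutputStreamToConsumers(stream_index)) {
    LOG(INFO) << "output stream " << stream_index << " is consumed by node "
              << node_id;
  }
}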
"//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1038,15 +1034,13 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1055,15 +1049,13 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1072,15 +1064,13 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1090,15 +1080,13 @@ objc_library( srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1106,15 +1094,13 @@ objc_library( objc_library( name = "mps_threshold_calculator", srcs = ["MPSThresholdCalculator.mm"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 388960b11..628e86099 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -3,6 +3,7 @@ #include #include +#include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -25,57 +26,101 @@ struct StorageTypeFormatter { } // namespace std::string GpuBuffer::DebugString() const { - return absl::StrCat("GpuBuffer[", - absl::StrJoin(storages_, ", ", StorageTypeFormatter()), - "]"); + return holder_ ? 
absl::StrCat("GpuBuffer[", width(), "x", height(), " ", + format(), " as ", holder_->DebugString(), "]") + : "GpuBuffer[invalid]"; } -internal::GpuBufferStorage* GpuBuffer::GetStorageForView( +std::string GpuBuffer::StorageHolder::DebugString() const { + absl::MutexLock lock(&mutex_); + return absl::StrJoin(storages_, ", ", StorageTypeFormatter()); +} + +internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView( TypeId view_provider_type, bool for_writing) const { - const std::shared_ptr* chosen_storage = nullptr; + std::shared_ptr chosen_storage; + std::function()> conversion; - // First see if any current storage supports the view. - for (const auto& s : storages_) { - if (s->can_down_cast_to(view_provider_type)) { - chosen_storage = &s; - break; - } - } - - // Then try to convert existing storages to one that does. - // TODO: choose best conversion. - if (!chosen_storage) { + { + absl::MutexLock lock(&mutex_); + // First see if any current storage supports the view. for (const auto& s : storages_) { - if (auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider( - view_provider_type, s->storage_type())) { - if (auto new_storage = converter(s)) { - storages_.push_back(new_storage); - chosen_storage = &storages_.back(); + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + conversion = absl::bind_front(converter, s); break; } } } } + // Avoid invoking a converter or factory while holding the mutex. + // Two reasons: + // 1. Readers that don't need a conversion will not be blocked. + // 2. We use mutexes to make sure GL contexts are not used simultaneously on + // different threads, and we also rely on Mutex's deadlock detection + // heuristic, which enforces a consistent mutex acquisition order. + // This function is likely to be called within a GL context, and the + // conversion function may in turn use a GL context, and this may cause a + // false positive in the deadlock detector. + // TODO: we could use Mutex::ForgetDeadlockInfo instead. + if (conversion) { + auto new_storage = conversion(); + absl::MutexLock lock(&mutex_); + // Another reader might have already completed and inserted the same + // conversion. TODO: prevent this? + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + if (!chosen_storage) { + storages_.push_back(std::move(new_storage)); + chosen_storage = storages_.back(); + } + } + if (for_writing) { + // This will temporarily hold storages to be released, and do so while the + // lock is not held (see above). + decltype(storages_) old_storages; + using std::swap; if (chosen_storage) { // Discard all other storages. - storages_ = {*chosen_storage}; - chosen_storage = &storages_.back(); + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {chosen_storage}; } else { // Allocate a new storage supporting the requested view. 
if (auto factory = internal::GpuBufferStorageRegistry::Get() .StorageFactoryForViewProvider(view_provider_type)) { - if (auto new_storage = factory(width(), height(), format())) { + if (auto new_storage = factory(width_, height_, format_)) { + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); storages_ = {std::move(new_storage)}; - chosen_storage = &storages_.back(); + chosen_storage = storages_.back(); } } } } - return chosen_storage ? chosen_storage->get() : nullptr; + + // It is ok to return a non-owning storage pointer here because this object + // ensures the storage's lifetime. Overwriting a GpuBuffer while readers are + // active would violate this, but it's not allowed in MediaPipe. + return chosen_storage ? chosen_storage.get() : nullptr; } internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( @@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " - << absl::StrJoin(storages_, ", ", - StorageTypeFormatter()); + << (holder_ ? holder_->DebugString() : "invalid"); DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); return *chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 56507d92f..b9a88aa53 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -15,9 +15,12 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_H_ +#include <functional> +#include <memory> #include <string> #include <utility> +#include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -56,8 +59,7 @@ class GpuBuffer { // Creates an empty buffer of a given size and format. It will be allocated // when a view is requested. GpuBuffer(int width, int height, Format format) - : GpuBuffer(std::make_shared<PlaceholderGpuBufferStorage>(width, height, - format)) {} + : holder_(std::make_shared<StorageHolder>(width, height, format)) {} // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; @@ -70,9 +72,8 @@ // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. - explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) { - storages_.push_back(std::move(storage)); - } + explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) + : holder_(std::make_shared<StorageHolder>(std::move(storage))) {} #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from @@ -84,9 +85,11 @@ : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} #endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - int width() const { return current_storage().width(); } - int height() const { return current_storage().height(); } - GpuBufferFormat format() const { return current_storage().format(); } + int width() const { return holder_ ? holder_->width() : 0; } + int height() const { return holder_ ? holder_->height() : 0; } + GpuBufferFormat format() const { + return holder_ ? holder_->format() : GpuBufferFormat::kUnknown; + } // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } 
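The locking discipline above (look up under the mutex, run the expensive converter or factory outside it, then re-check before inserting) is a general pattern. A minimal sketch detached from the GPU types, under no assumptions beyond absl:

#include <functional>
#include <memory>
#include <vector>

#include "absl/synchronization/mutex.h"

// Picks or creates an item without running the factory under the mutex,
// then re-checks before inserting, since another thread may have inserted
// an equivalent item in the meantime.
template <class T>
std::shared_ptr<T> GetOrCreate(absl::Mutex& mu,
                               std::vector<std::shared_ptr<T>>& items,
                               std::function<bool(const T&)> matches,
                               std::function<std::shared_ptr<T>()> factory) {
  {
    absl::MutexLock lock(&mu);
    for (const auto& item : items)
      if (matches(*item)) return item;
  }
  std::shared_ptr<T> created = factory();  // runs without holding the mutex
  absl::MutexLock lock(&mu);
  for (const auto& item : items)
    if (matches(*item)) return item;  // lost the race: reuse, drop `created`
  items.push_back(created);
  return created;
}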
@@ -122,31 +125,17 @@ // using views. template <class T> std::shared_ptr<T> internal_storage() const { - for (const auto& s : storages_) - if (s->down_cast<T>()) return std::static_pointer_cast<T>(s); - return nullptr; + return holder_ ? holder_->internal_storage<T>() : nullptr; } std::string DebugString() const; private: - class PlaceholderGpuBufferStorage - : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> { - public: - PlaceholderGpuBufferStorage(int width, int height, Format format) - : width_(width), height_(height), format_(format) {} - int width() const override { return width_; } - int height() const override { return height_; } - GpuBufferFormat format() const override { return format_; } - - private: - int width_ = 0; - int height_ = 0; - GpuBufferFormat format_ = GpuBufferFormat::kUnknown; - }; - internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, - bool for_writing) const; + bool for_writing) const { + return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing) + : nullptr; + } internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, bool for_writing) const; @@ -158,25 +147,49 @@ class GpuBuffer { .template down_cast<VP>(); } - std::shared_ptr<internal::GpuBufferStorage>& no_storage() const { - static auto placeholder = - std::static_pointer_cast<internal::GpuBufferStorage>( - std::make_shared<PlaceholderGpuBufferStorage>( - 0, 0, GpuBufferFormat::kUnknown)); - return placeholder; - } + // This class manages a set of alternative storages for the contents of a + // GpuBuffer. GpuBuffer was originally designed as a reference-type object, + // where a copy represents another reference to the same contents, so multiple + // GpuBuffer instances can share the same StorageHolder. + class StorageHolder { + public: + explicit StorageHolder(std::shared_ptr<internal::GpuBufferStorage> storage) + : StorageHolder(storage->width(), storage->height(), + storage->format()) { + storages_.push_back(std::move(storage)); + } + explicit StorageHolder(int width, int height, Format format) + : width_(width), height_(height), format_(format) {} - const internal::GpuBufferStorage& current_storage() const { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + int width() const { return width_; } + int height() const { return height_; } + GpuBufferFormat format() const { return format_; } - internal::GpuBufferStorage& current_storage() { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const; - // This is mutable because view methods that do not change the contents may - // still need to allocate new storages. - mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_; + template <class T> + std::shared_ptr<T> internal_storage() const { + absl::MutexLock lock(&mutex_); + for (const auto& s : storages_) + if (s->down_cast<T>()) return std::static_pointer_cast<T>(s); + return nullptr; + } + + std::string DebugString() const; + + private: + int width_ = 0; + int height_ = 0; + GpuBufferFormat format_ = GpuBufferFormat::kUnknown; + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. 
+ mutable absl::Mutex mutex_; + mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_ + ABSL_GUARDED_BY(mutex_); + }; + + std::shared_ptr<StorageHolder> holder_; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); @@ -184,15 +197,15 @@ }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storages_.empty(); + return holder_ == other; } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storages_ == other.storages_; + return holder_ == other.holder_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storages_.clear(); + holder_ = other; return *this; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 145b71806..e4be617db 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" @@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { EXPECT_TRUE(true); } +TEST_F(GpuBufferTest, CopiesShareConversions) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr<ImageFrame> view = buffer.GetWriteView<ImageFrame>(); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer other_handle = buffer; + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetReadView<GlTextureView>(0); + }); + + // Check that other_handle also sees the same GlTextureBuffer as buffer. + // Note that this is deliberately written so that it still passes on platforms + // where we use another storage for GL textures (they will both be null). + // TODO: expose more accessors for testing? + EXPECT_EQ(other_handle.internal_storage<GlTextureBuffer>(), + buffer.internal_storage<GlTextureBuffer>()); +} + } // anonymous namespace } // namespace mediapipe
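What the new test checks, in miniature: copying a GpuBuffer copies a reference to the shared StorageHolder, not the storages themselves, so a conversion triggered through one handle becomes visible through the other. A sketch, with the actual view request elided:

#include "mediapipe/gpu/gpu_buffer.h"

void ShareAcrossHandles() {
  mediapipe::GpuBuffer buffer(64, 64, mediapipe::GpuBufferFormat::kBGRA32);
  mediapipe::GpuBuffer alias = buffer;  // shallow copy: shared StorageHolder
  // A conversion performed through `buffer` (e.g. requesting a GL texture
  // view) would afterwards be reusable through `alias` without converting
  // the contents a second time.
}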
diff --git a/mediapipe/gpu/metal_shared_resources.h b/mediapipe/gpu/metal_shared_resources.h new file mode 100644 index 000000000..341860a2d --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.h @@ -0,0 +1,40 @@ +#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ +#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ + +#import <CoreVideo/CVMetalTextureCache.h> +#import <CoreVideo/CoreVideo.h> +#import <Foundation/Foundation.h> +#import <Metal/Metal.h> + +#ifndef __OBJC__ +#error This class must be built as Objective-C++. +#endif // !__OBJC__ + +@interface MPPMetalSharedResources : NSObject { +} + +- (instancetype)init NS_DESIGNATED_INITIALIZER; + +@property(readonly) id<MTLDevice> mtlDevice; +@property(readonly) id<MTLCommandQueue> mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@property(readonly) CVMetalTextureCacheRef mtlTextureCache; +#endif + +@end + +namespace mediapipe { + +class MetalSharedResources { + public: + MetalSharedResources(); + ~MetalSharedResources(); + MPPMetalSharedResources* resources() { return resources_; } + + private: + MPPMetalSharedResources* resources_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.mm new file mode 100644 index 000000000..80d755a01 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.mm @@ -0,0 +1,73 @@ +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResources () +@end + +@implementation MPPMetalSharedResources { +} + +@synthesize mtlDevice = _mtlDevice; +@synthesize mtlCommandQueue = _mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@synthesize mtlTextureCache = _mtlTextureCache; +#endif + +- (instancetype)init { + self = [super init]; + if (self) { + } + return self; +} + +- (void)dealloc { +#if COREVIDEO_SUPPORTS_METAL + if (_mtlTextureCache) { + CFRelease(_mtlTextureCache); + _mtlTextureCache = NULL; + } +#endif +} + +- (id<MTLDevice>)mtlDevice { + @synchronized(self) { + if (!_mtlDevice) { + _mtlDevice = MTLCreateSystemDefaultDevice(); + } + } + return _mtlDevice; +} + +- (id<MTLCommandQueue>)mtlCommandQueue { + @synchronized(self) { + if (!_mtlCommandQueue) { + _mtlCommandQueue = [self.mtlDevice newCommandQueue]; + } + } + return _mtlCommandQueue; +} + +#if COREVIDEO_SUPPORTS_METAL +- (CVMetalTextureCacheRef)mtlTextureCache { + @synchronized(self) { + if (!_mtlTextureCache) { + CVReturn __unused err = + CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err, + self.mtlDevice); + // TODO: register and flush metal caches too. + } + } + return _mtlTextureCache; +} +#endif + +@end + +namespace mediapipe { + +MetalSharedResources::MetalSharedResources() { + resources_ = [[MPPMetalSharedResources alloc] init]; +} +MetalSharedResources::~MetalSharedResources() {} + +} // namespace mediapipe
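The accessors above create the device, queue, and texture cache lazily under @synchronized(self). For comparison, the same shape in portable C++ is a call_once-guarded getter; a generic sketch, not MediaPipe API:

#include <functional>
#include <mutex>
#include <utility>

// Generic lazily-initialized value, analogous to the @synchronized getters.
template <class T>
class Lazy {
 public:
  explicit Lazy(std::function<T()> make) : make_(std::move(make)) {}
  const T& get() {
    std::call_once(once_, [this] { value_ = make_(); });  // first caller wins
    return value_;
  }

 private:
  std::once_flag once_;
  std::function<T()> make_;
  T value_{};
};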
diff --git a/mediapipe/gpu/metal_shared_resources_test.mm b/mediapipe/gpu/metal_shared_resources_test.mm new file mode 100644 index 000000000..9eb53a9b7 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources_test.mm @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#import <XCTest/XCTest.h> + +#include <memory> + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/threadpool.h" + +#import "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResourcesTests : XCTestCase { +} +@end + +@implementation MPPMetalSharedResourcesTests + +// This test verifies that the internal Objective-C object is correctly +// released when the C++ wrapper is released. +- (void)testCorrectlyReleased { + __weak id metalRes = nil; + std::weak_ptr<mediapipe::GpuResources> weakGpuRes; + @autoreleasepool { + auto maybeGpuRes = mediapipe::GpuResources::Create(); + XCTAssertTrue(maybeGpuRes.ok()); + weakGpuRes = *maybeGpuRes; + metalRes = (**maybeGpuRes).metal_shared().resources(); + XCTAssertNotEqual(weakGpuRes.lock(), nullptr); + XCTAssertNotNil(metalRes); + } + XCTAssertEqual(weakGpuRes.lock(), nullptr); + XCTAssertNil(metalRes); +} + +@end diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index b177726bf..945f02c62 100644 --- a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:color_convert_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt index 36c6970e1..f48b26be0 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt @@ -60,7 +60,14 @@ node { tag_index: "LOOP" back_edge: true } - output_stream: "PREV_LOOP:previous_hair_mask" + output_stream: "PREV_LOOP:previous_hair_mask_rgb" +} + +# Converts the 4 channel hair mask to a single channel mask +node { + calculator: "ColorConvertCalculator" + input_stream: "RGB_IN:previous_hair_mask_rgb" + output_stream: "GRAY_OUT:previous_hair_mask" } # Embeds the hair mask generated from the previous round of hair segmentation diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index 3a3d57a0f..f5124b464 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -97,7 +97,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:image_file_properties_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index 783fff187..d4c5c496b 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "gl_animation_overlay_calculator_proto", srcs = ["gl_animation_overlay_calculator.proto"], + def_options_lib = False, visibility = ["//visibility:public"], exports = [ "//mediapipe/gpu:gl_animation_overlay_calculator_proto", diff --git 
a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 4926e2f3c..4540f63a6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -84,12 +84,11 @@ cc_library( deps = [ ":class_registry", ":jni_util", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_profile_cc_proto", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/model_maker/MANIFEST.in b/mediapipe/model_maker/MANIFEST.in new file mode 100644 index 000000000..54ce01aff --- /dev/null +++ b/mediapipe/model_maker/MANIFEST.in @@ -0,0 +1 @@ +recursive-include pip_src/mediapipe_model_maker/models * diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 7ca2f9216..9899a145b 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.text import text_classifier diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index bccf928e2..66addad54 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -19,7 +19,7 @@ import os def get_absolute_path(file_path: str) -> str: - """Gets the absolute path of a file. + """Gets the absolute path of a file in the model_maker directory. Args: file_path: The path to a file relative to the `mediapipe` dir @@ -27,10 +27,17 @@ Returns: The full path of the file """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `path` which defines the relative path under mediapipe/, it - # yields to the absolute path of the model files directory. + # Extract the file path before and including 'model_maker' as the + # `mm_base_dir`. Joining it with the part of `file_path` after + # 'model_maker/' yields the absolute path of the model files directory. + # We must join on 'model_maker' because in the pypi package, the + # 'model_maker' directory is renamed to 'mediapipe_model_maker', so + # `mm_base_dir` must include the renamed directory. 
cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, file_path) + cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') + mm_base_dir = cwd[:cwd_stop_idx] + file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 + mm_relative_path = file_path[file_path_start_idx:] + absolute_path = os.path.join(mm_base_dir, mm_relative_path) return absolute_path diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/model_maker/python/text/__init__.py similarity index 63% rename from mediapipe/tasks/cc/components/proto/BUILD rename to mediapipe/model_maker/python/text/__init__.py index 569023753..7ca2f9216 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/model_maker/python/text/__init__.py @@ -4,21 +4,10 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 256447a8d..9123e36b0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -35,20 +35,21 @@ py_library( srcs = ["constants.py"], ) -# TODO: Change to py_library after migrating the MediaPipe hand solution -# library to MediaPipe hand task library. 
py_library( name = "dataset", srcs = ["dataset.py"], deps = [ ":constants", + ":metadata_writer", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/core/data:data_util", "//mediapipe/model_maker/python/core/utils:model_util", - "//mediapipe/python/solutions:hands", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "dataset_test", srcs = ["dataset_test.py"], @@ -56,10 +57,11 @@ py_test( ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], + tags = ["notsan"], deps = [ ":dataset", - "//mediapipe/python/solutions:hands", "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) @@ -131,6 +133,7 @@ py_library( ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "gesture_recognizer_test", size = "large", @@ -140,6 +143,7 @@ "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, + tags = ["notsan"], deps = [ ":gesture_recognizer_import", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 256f26fd6..6a2c878c0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -16,16 +16,22 @@ import dataclasses import os import random -from typing import List, NamedTuple, Optional +from typing import List, Optional -import cv2 import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.core.data import data_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.vision.gesture_recognizer import constants -from mediapipe.python.solutions import hands as mp_hands +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module + +_Image = image_module.Image +_HandLandmarker = hand_landmarker_module.HandLandmarker +_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions +_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult @dataclasses.dataclass @@ -59,7 +65,7 @@ class HandData: handedness: List[float] -def _validate_data_sample(data: NamedTuple) -> bool: +def _validate_data_sample(data: _HandLandmarkerResult) -> bool: """Validates the input hand data sample. Args: @@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool: - 'multi_hand_landmarks' or 'multi_hand_world_landmarks' or - 'multi_handedness' or any of these attributes' values are none. Otherwise, True. + 'hand_landmarks' or 'hand_world_landmarks' or 'handedness', or any of + these attributes' values are None or empty. Otherwise, True.
""" - if (not hasattr(data, 'multi_hand_landmarks') or - data.multi_hand_landmarks is None): + if data.hand_landmarks is None or not data.hand_landmarks: return False - if (not hasattr(data, 'multi_hand_world_landmarks') or - data.multi_hand_world_landmarks is None): + if data.hand_world_landmarks is None or not data.hand_world_landmarks: return False - if not hasattr(data, 'multi_handedness') or data.multi_handedness is None: + if data.handedness is None or not data.handedness: return False return True def _get_hand_data(all_image_paths: List[str], - min_detection_confidence: float) -> Optional[HandData]: + min_detection_confidence: float) -> List[Optional[HandData]]: """Computes hand data (landmarks and handedness) in the input image. Args: @@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str], - A HandData object. Returns None if no hand is detected. + A list with one entry per image path: a HandData object, or None if no + hand is detected in that image. """ hand_data_result = [] - with mp_hands.Hands( - static_image_mode=True, - max_num_hands=1, - min_detection_confidence=min_detection_confidence) as hands: + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + hand_landmarker_options = _HandLandmarkerOptions( + base_options=base_options_module.BaseOptions( + model_asset_buffer=hand_landmarker_writer.populate()), + num_hands=1, + min_hand_detection_confidence=min_detection_confidence, + min_hand_presence_confidence=0.5, + min_tracking_confidence=1, + ) + with _HandLandmarker.create_from_options( + hand_landmarker_options) as hand_landmarker: for path in all_image_paths: tf.compat.v1.logging.info('Loading image %s', path) - image = data_util.load_image(path) - # Flip image around y-axis for correct handedness output - image = cv2.flip(image, 1) - data = hands.process(image) + image = _Image.create_from_file(path) + data = hand_landmarker.detect(image) if not _validate_data_sample(data): hand_data_result.append(None) continue - hand_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_landmarks[0].landmark] + hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z] + for hand_landmark in data.hand_landmarks[0]] hand_world_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark] + ] for hand_landmark in data.hand_world_landmarks[0]] handedness_scores = [ - handedness.score - for handedness in data.multi_handedness[0].classification + handedness.score for handedness in data.handedness[0] ] hand_data_result.append( HandData( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 76e70a58d..528d02edd 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License.
-import collections import os import shutil from typing import NamedTuple import unittest -from absl import flags from absl.testing import parameterized import tensorflow as tf from mediapipe.model_maker.python.vision.gesture_recognizer import dataset -from mediapipe.python.solutions import hands as mp_hands from mediapipe.tasks.python.test import test_utils - -FLAGS = flags.FLAGS +from mediapipe.tasks.python.vision import hand_landmarker _TEST_DATA_DIRNAME = 'raw_data' @@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) train_data, test_data = data.split(0.5) - self.assertLen(train_data, 17) + self.assertLen(train_data, 16) for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) self.assertEqual(train_data.num_classes, 4) self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) - self.assertLen(test_data, 18) + self.assertLen(test_data, 16) for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) @@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) @@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK']) @parameterized.named_parameters( dict( - testcase_name='invalid_field_name_multi_hand_landmark', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmark', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=None, hand_landmarks=[[2]], + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_hand_world_landmarks', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmark', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=None, + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_handed', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handed' - ])(1, 2, 3)), + testcase_name='none_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], + hand_world_landmarks=None)), dict( - testcase_name='multi_hand_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(None, 2, 3)), + testcase_name='empty_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_hand_world_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 
'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, None, 3)), + testcase_name='empty_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_handedness_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, None)), + testcase_name='empty_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), ) def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): with unittest.mock.patch.object( - mp_hands.Hands, 'process', return_value=hand): + hand_landmarker.HandLandmarker, 'detect', return_value=hand): input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): dataset.Dataset.from_folder( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index 58b67e072..b2e851afe 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: return f.read() +class HandLandmarkerMetadataWriter: + """MetadataWriter to write the model asset bundle for HandLandmarker.""" + + def __init__( + self, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + ) -> None: + """Initializes HandLandmarkerMetadataWriter to write model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + def populate(self): + """Creates the model asset bundle for hand landmarker task. + + Returns: + Model asset bundle in bytes + """ + landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) + return hand_landmarker_model_buffer + + class MetadataWriter: """MetadataWriter to write the metadata and the model asset bundle.""" @@ -86,8 +130,8 @@ class MetadataWriter: custom_gesture_classifier_metadata_writer: Metadata writer to write custom gesture classifier metadata into the TFLite file. 
""" - self._hand_detector_model_buffer = hand_detector_model_buffer - self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) self._gesture_embedder_model_buffer = gesture_embedder_model_buffer self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer @@ -147,16 +191,8 @@ class MetadataWriter: A tuple of (model_asset_bundle_in_bytes, metadata_json_content) """ # Creates the model asset bundle for hand landmarker task. - landmark_models = { - _HAND_DETECTOR_TFLITE_NAME: - self._hand_detector_model_buffer, - _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: - self._hand_landmarks_detector_model_buffer - } - output_hand_landmarker_path = os.path.join(self._temp_folder.name, - _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate( + ) # Write metadata into custom gesture classifier model. self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( @@ -179,7 +215,7 @@ class MetadataWriter: # graph. gesture_recognizer_models = { _HAND_LANDMARKER_BUNDLE_NAME: - read_file(output_hand_landmarker_path), + hand_landmarker_model_buffer, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: read_file(output_hand_gesture_recognizer_path), } diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index 83998141d..fd26b274d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( class MetadataWriterTest(tf.test.TestCase): + def test_hand_landmarker_metadata_writer(self): + # Use dummy model buffer for unit test only. + hand_detector_model_buffer = b"\x11\x12" + hand_landmarks_detector_model_buffer = b"\x22" + writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + model_bundle_content = writer.populate() + model_bundle_filepath = os.path.join(self.get_temp_dir(), + "hand_landmarker.task") + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarks_detector.tflite", "hand_detector.tflite"])) + def test_write_metadata_and_create_model_asset_bundle_successful(self): # Use dummy model buffer for unit test only. 
hand_detector_model_buffer = b"\x11\x12" diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 389ee484a..9b3c9f906 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,6 +1,8 @@ absl-py +mediapipe==0.9.1 numpy -opencv-contrib-python -tensorflow +opencv-python +tensorflow>=2.10 tensorflow-datasets tensorflow-hub +tf-models-official>=2.10.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py new file mode 100644 index 000000000..ea193db94 --- /dev/null +++ b/mediapipe/model_maker/setup.py @@ -0,0 +1,147 @@ +"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Setup for Mediapipe-Model-Maker package with setuptools. +""" + +import glob +import os +import shutil +import subprocess +import sys +import setuptools + + +__version__ = 'dev' +MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +# Build dir to copy all necessary files and build package +SRC_NAME = 'pip_src' +BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME) +BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker') + + +def _parse_requirements(path): + with open(os.path.join(MM_ROOT_PATH, path)) as f: + return [ + line.rstrip() + for line in f + if not (line.isspace() or line.startswith('#')) + ] + + +def _copy_to_pip_src_dir(file): + """Copy a file from bazel-bin to the pip_src dir.""" + dst = file + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) + shutil.copyfile(src_file, file) + + +def _setup_build_dir(): + """Sets up the BUILD_DIR directory to build the mediapipe_model_maker package. + + We need to create a new BUILD_DIR directory because any references to the path + `mediapipe/model_maker` need to be renamed to `mediapipe_model_maker` to + avoid conflicting with the mediapipe package name. + This setup function performs the following actions: + 1. Copy python source code into BUILD_DIR and rename imports to + mediapipe_model_maker + 2.
Download models from GCS into BUILD_DIR + """ + # Copy python source code into BUILD_DIR + if os.path.exists(BUILD_DIR): + shutil.rmtree(BUILD_DIR) + python_files = glob.glob('python/**/*.py', recursive=True) + python_files.append('__init__.py') + for python_file in python_files: + # Exclude test files from pip package + if '_test.py' in python_file: + continue + build_target_file = os.path.join(BUILD_MM_DIR, python_file) + with open(python_file, 'r') as file: + filedata = file.read() + # Rename all mediapipe.model_maker imports to mediapipe_model_maker + filedata = filedata.replace('from mediapipe.model_maker', + 'from mediapipe_model_maker') + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + with open(build_target_file, 'w') as file: + file.write(filedata) + + # Use bazel to download GCS model files + model_build_files = ['models/gesture_recognizer/BUILD'] + for model_build_file in model_build_files: + build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + shutil.copy(model_build_file, build_target_file) + external_files = [ + 'models/gesture_recognizer/canned_gesture_classifier.tflite', + 'models/gesture_recognizer/gesture_embedder.tflite', + 'models/gesture_recognizer/hand_landmark_full.tflite', + 'models/gesture_recognizer/palm_detection_full.tflite', + 'models/gesture_recognizer/gesture_embedder/keras_metadata.pb', + 'models/gesture_recognizer/gesture_embedder/saved_model.pb', + 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', + 'models/gesture_recognizer/gesture_embedder/variables/variables.index', + ] + for elem in external_files: + external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) + sys.stderr.write('downloading file: %s\n' % external_file) + fetch_model_command = [ + 'bazel', + 'build', + external_file, + ] + if subprocess.call(fetch_model_command) != 0: + sys.exit(-1) + _copy_to_pip_src_dir(external_file) + +_setup_build_dir() + +setuptools.setup( + name='mediapipe-model-maker', + version=__version__, + url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + description='MediaPipe Model Maker is a simple, low-code solution for customizing on-device ML models', + author='The MediaPipe Authors', + author_email='mediapipe@google.com', + long_description='', + long_description_content_type='text/markdown', + packages=setuptools.find_packages(where=SRC_NAME), + package_dir={'': SRC_NAME}, + install_requires=_parse_requirements('requirements.txt'), + include_package_data=True, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3 :: Only', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords=['mediapipe', 'model', 'maker'], +) diff --git a/mediapipe/modules/hand_landmark/calculators/BUILD 
b/mediapipe/modules/hand_landmark/calculators/BUILD index b2a8efe37..b42ec94de 100644 --- a/mediapipe/modules/hand_landmark/calculators/BUILD +++ b/mediapipe/modules/hand_landmark/calculators/BUILD @@ -24,7 +24,6 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 6f2c49d64..638678ff5 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { // NORM_LANDMARKS is either the full set of landmarks for the hand, or diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 0da6cd7f7..49c7b93fb 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kTrackingRectTag[] = "TRACKING_RECT"; +using ::mediapipe::NormalizedRect; + // TODO: Use rect rotation. // Verifies that Intersection over Union of previous frame rect and current // frame re-crop rect is less than threshold. 
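Stepping back to `_setup_build_dir` in `mediapipe/model_maker/setup.py` above: the import-rewriting step is plain string substitution. A toy sketch (hypothetical input line) of what it does to each copied source file:

```python
# Every 'from mediapipe.model_maker' import in the copied sources is rewritten
# so the pip package can live under the 'mediapipe_model_maker' top-level name.
src = 'from mediapipe.model_maker.python.core.utils import quantization\n'
rewritten = src.replace('from mediapipe.model_maker', 'from mediapipe_model_maker')
assert rewritten == 'from mediapipe_model_maker.python.core.utils import quantization\n'
```

This rename is also why `file_util.get_absolute_path` joins on 'model_maker' rather than 'mediapipe': only the former substring survives in the pip package layout.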
diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index eeeaee5f4..14cea526f 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -275,7 +275,6 @@ cc_library( ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -299,7 +298,6 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -316,13 +314,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_cc_proto", - ":belief_decoder_config_cc_proto", ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 476f8cb54..1fe919c54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -34,6 +34,8 @@ namespace { constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION"; constexpr char kOutputNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator that converts FrameAnnotation proto to NormalizedRect. diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index fafdfee8a..c71c02b6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -68,7 +68,6 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = ["Accelerate"], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ @@ -90,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -120,13 +120,13 @@ objc_library( ], "//conditions:default": [], }), - sdk_frameworks = [ - "AVFoundation", - "CoreVideo", - "Foundation", - ], # This build rule is public to allow external customers to build their own iOS apps. 
visibility = ["//visibility:public"], + deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", + ], ) objc_library( @@ -140,16 +140,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -164,16 +162,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", "//mediapipe/gpu:gl_calculator_helper", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -188,13 +184,11 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Foundation", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", "@com_google_absl//absl/strings", ], ) @@ -211,23 +205,21 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "OpenGLES", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", ":Weakify", ":mediapipe_framework_ios", "//mediapipe/framework:calculator_framework", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:OpenGLES", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) @@ -245,16 +237,6 @@ objc_library( data = [ "testdata/googlelogo_color_272x92dp.png", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", @@ -263,6 +245,14 @@ objc_library( ":mediapipe_framework_ios", ":mediapipe_input_sources_ios", "//mediapipe/calculators/core:pass_through_calculator", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index f61472413..c575caabe 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_classifier", srcs = ["audio_classifier.cc"], hdrs = ["audio_classifier.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + 
visibility = ["//visibility:public"], deps = [ ":audio_classifier_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 5d4ba3296..cc26b3070 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index 6a0f627b2..1dfdd6f1b 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_embedder", srcs = ["audio_embedder.cc"], hdrs = ["audio_embedder.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto index 25c5d5474..367a1bf26 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 35d3f4785..0750a1482 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = "detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..43c8ca0f5 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +#include <stdint.h> + +#include <optional> +#include <string> +#include <vector> + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +constexpr int kDefaultCategoryIndex = -1; + +Detection ConvertToDetection( + const mediapipe::Detection& detection_proto) { + Detection detection; + for (int idx = 0; idx < detection_proto.score_size(); ++idx) { + detection.categories.push_back( + {/* index= */ detection_proto.label_id_size() > idx + ? detection_proto.label_id(idx) + : kDefaultCategoryIndex, + /* score= */ detection_proto.score(idx), + /* category_name */ detection_proto.label_size() > idx + ? detection_proto.label(idx) + : "", + /* display_name */ detection_proto.display_name_size() > idx + ? detection_proto.display_name(idx) + : ""}); + } + Rect bounding_box; + if (detection_proto.location_data().has_bounding_box()) { + mediapipe::LocationData::BoundingBox bounding_box_proto = + detection_proto.location_data().bounding_box(); + bounding_box.left = bounding_box_proto.xmin(); + bounding_box.top = bounding_box_proto.ymin(); + bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width(); + bounding_box.bottom = + bounding_box_proto.ymin() + bounding_box_proto.height(); + } + detection.bounding_box = bounding_box; + return detection; +} + +DetectionResult ConvertToDetectionResult( + std::vector<mediapipe::Detection> detections_proto) { + DetectionResult detection_result; + detection_result.detections.reserve(detections_proto.size()); + for (const auto& detection_proto : detections_proto) { + detection_result.detections.push_back( + ConvertToDetection(detection_proto)); + } + return detection_result; +} +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h new file mode 100644 index 000000000..546f324d6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include <optional> +#include <string> +#include <vector> + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +// Detection for a single bounding box. +struct Detection { + // A vector of detected categories. + std::vector<Category> categories; + // The bounding box location. + Rect bounding_box; +}; + +// Detection results of a model. +struct DetectionResult { + // A vector of Detections. + std::vector<Detection> detections; +}; + +// Utility function to convert from Detection proto to Detection struct. +Detection ConvertToDetection(const mediapipe::Detection& detection_proto); + +// Utility function to convert from list of Detection proto to DetectionResult +// struct. +DetectionResult ConvertToDetectionResult( + std::vector<mediapipe::Detection> detections_proto); + +} // namespace mediapipe::tasks::components::containers +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc new file mode 100644 index 000000000..4a94832a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/rect.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +RectF ToRectF(const Rect& rect, int image_height, int image_width) { + return RectF{static_cast<float>(rect.left) / image_width, + static_cast<float>(rect.top) / image_height, + static_cast<float>(rect.right) / image_width, + static_cast<float>(rect.bottom) / image_height}; +} + +Rect ToRect(const RectF& rect, int image_height, int image_width) { + return Rect{static_cast<int>(rect.left * image_width), + static_cast<int>(rect.top * image_height), + static_cast<int>(rect.right * image_width), + static_cast<int>(rect.bottom * image_height)}; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h index 3f5432cf2..551d91588 100644 --- a/mediapipe/tasks/cc/components/containers/rect.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -16,20 +16,47 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#include <cmath> + namespace mediapipe::tasks::components::containers { +constexpr float kRectFTolerance = 1e-4; + // Defines a rectangle, used e.g. as part of detection results or as input // region-of-interest. //
// +struct Rect { + int left; + int top; + int right; + int bottom; +}; + +inline bool operator==(const Rect& lhs, const Rect& rhs) { + return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right && + lhs.bottom == rhs.bottom; +} + // The coordinates are normalized wrt the image dimensions, i.e. generally in // [0,1] but they may exceed these bounds if describing a region overlapping the // image. The origin is on the top-left corner of the image. -struct Rect { +struct RectF { float left; float top; float right; float bottom; }; +inline bool operator==(const RectF& lhs, const RectF& rhs) { + return abs(lhs.left - rhs.left) < kRectFTolerance && + abs(lhs.top - rhs.top) < kRectFTolerance && + abs(lhs.right - rhs.right) < kRectFTolerance && + abs(lhs.bottom - rhs.bottom) < kRectFTolerance; +} + +RectF ToRectF(const Rect& rect, int image_height, int image_width); + +Rect ToRect(const RectF& rect, int image_height, int image_width); + } // namespace mediapipe::tasks::components::containers #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 185bf231b..cec44a9e3 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -150,9 +150,12 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index b24b7f0cb..fefc1ec52 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -45,6 +45,7 @@ namespace components { namespace processors { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::Tensor; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index f48c4bad8..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -60,10 +60,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], deps = [ + ":text_model_type_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. + STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index a67cfd8a9..b610f7757 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -18,25 +18,16 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. - STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index de16375bd..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -25,15 +25,14 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { -namespace processors { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::components::processors::proto:: TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. 
-absl::StatusOr<TextPreprocessingGraphOptions::PreprocessorType> -GetPreprocessorType(const ModelResources& model_resources) { - const tflite::SubGraph& model_graph = - *(*model_resources.GetTfLiteModel()->subgraphs())[0]; - bool all_int32_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; - }); - bool all_string_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; - }); - if (!all_int32_tensors && !all_string_tensors) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "All input tensors should have type int32 or all should have type " - "string", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); - } - if (all_string_tensors) { - return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; - } - - // Otherwise, all tensors should have type int32 - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); - if (metadata_extractor->GetModelMetadata() == nullptr || - metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Text models with int32 input tensors require TFLite Model " - "Metadata but none was found", - MediaPipeTasksStatus::kMetadataNotFoundError); - } - - if (model_graph.inputs()->size() == kNumInputTensorsForBert) { - return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; - } - - if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { - return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; - } - - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::Substitute("Models with int32 input tensors should take exactly $0 " - "or $1 input tensors, but found $2", - kNumInputTensorsForBert, kNumInputTensorsForRegex, - model_graph.inputs()->size()), - MediaPipeTasksStatus::kInvalidNumInputTensorsError); -} - // Returns the maximum input sequence length accepted by the TFLite // model that owns `model_graph` or returns an error if the model's input // tensors' shape is invalid for text preprocessing.
This util assumes that the @@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph( MediaPipeTasksStatus::kInvalidArgumentError); } - ASSIGN_OR_RETURN( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, - GetPreprocessorType(model_resources)); - options.set_preprocessor_type(preprocessor_type); - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + ASSIGN_OR_RETURN(TextModelType::ModelType model_type, + GetModelType(model_resources)); + options.set_model_type(model_type); + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::BERT_MODEL: + case TextModelType::REGEX_MODEL: { ASSIGN_OR_RETURN( int max_seq_len, GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); @@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing( const TextPreprocessingGraphOptions& options, Source<std::string> text_in, SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) { - ASSIGN_OR_RETURN( - std::string preprocessor_name, - GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); + ASSIGN_OR_RETURN(std::string preprocessor_name, + GetCalculatorNameFromModelType(options.model_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); - switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + switch (options.model_type()) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { + case TextModelType::BERT_MODEL: { text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>() .set_bert_max_seq_len(options.max_seq_len()); metadata_extractor_in >> text_preprocessor.SideIn(kMetadataExtractorTag); break; } - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::REGEX_MODEL: { text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>() .set_max_seq_len(options.max_seq_len()); metadata_extractor_in >> @@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { REGISTER_MEDIAPIPE_GRAPH( ::mediapipe::tasks::components::processors::TextPreprocessingGraph); -} // namespace processors -} // namespace components -} // namespace tasks -} // namespace mediapipe +} // namespace mediapipe::tasks::components::processors diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 202f3ea3c..d440271df 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -117,6 +117,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/util:resource_util", + "//mediapipe/util:resource_util_custom", "//mediapipe/util/tflite:error_reporter", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -308,7 +309,10 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = [ + "//mediapipe/calculators:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", diff --git a/mediapipe/tasks/cc/core/model_resources.cc
b/mediapipe/tasks/cc/core/model_resources.cc index d5c12ee95..7819f6213 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/resource_util_custom.h" #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -99,21 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { -#ifdef __EMSCRIPTEN__ - // In browsers, the model file may require a custom ResourceProviderFn to - // provide the model content. The open() method may not work in this case. - // Thus, loading the model content from the model file path in advance with - // the help of GetResourceContents. - MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( - model_file_->file_name(), model_file_->mutable_file_content())); - model_file_->clear_file_name(); -#else - // If the model file name is a relative path, searches the file in a - // platform-specific location and returns the absolute path on success. - ASSIGN_OR_RETURN(std::string path_to_resource, - mediapipe::PathToResourceAsFile(model_file_->file_name())); - model_file_->set_file_name(path_to_resource); -#endif // __EMSCRIPTEN__ + if (HasCustomGlobalResourceProvider()) { + // If the model contents are provided via a custom ResourceProviderFn, the + // open() method may not work. Thus, loads the model content from the + // model file path in advance with the help of GetResourceContents. + MP_RETURN_IF_ERROR(GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); + } else { + // If the model file name is a relative path, searches the file in a + // platform-specific location and returns the absolute path on success. + ASSIGN_OR_RETURN(std::string path_to_resource, + PathToResourceAsFile(model_file_->file_name())); + model_file_->set_file_name(path_to_resource); + } } ASSIGN_OR_RETURN( model_file_handler_, diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 66434483b..0cb556ec2 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -186,7 +186,7 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources( absl::StatusOr<const ModelAssetBundleResources*> ModelTaskGraph::CreateModelAssetBundleResources( SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, - const std::string tag_suffix) { + std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 50dcc903b..3068b2c46 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -59,14 +59,16 @@ class ModelTaskGraph : public Subgraph { // creates a local model resources object that can only be used in the graph // construction stage. The returned model resources pointer will provide graph // authors with the access to the metadata extractor and the tflite model.
+ // If more than one model resource is created in a graph, the model + // resources graph service adds the tag_suffix to support multiple resources. template <typename Options> absl::StatusOr<const ModelResources*> CreateModelResources( - SubgraphContext* sc) { + SubgraphContext* sc, std::string tag_suffix = "") { auto external_file = std::make_unique<proto::ExternalFile>(); external_file->Swap(sc->MutableOptions<Options>() ->mutable_base_options() ->mutable_model_asset()); - return CreateModelResources(sc, std::move(external_file)); + return CreateModelResources(sc, std::move(external_file), tag_suffix); } // If the model resources graph service is available, creates a model @@ -83,7 +85,7 @@ class ModelTaskGraph : public Subgraph { // resources. absl::StatusOr<const ModelResources*> CreateModelResources( SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, - const std::string tag_suffix = ""); + std::string tag_suffix = ""); // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 8f4d7eea6..41f87b519 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..799885eac 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector categories; + +// Predicted scores are slightly different on Mac OS. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index e7e3a63c7..fc8e02858 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 7f1ea2848..92fac8eaa 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/framework:mediapipe_internal"]) +package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 710e8a984..092a7d450 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -43,3 +43,43 @@ cc_test( "@com_google_absl//absl/container:node_hash_map", ], ) + +cc_library( + name = "text_model_utils", + srcs = ["text_model_utils.cc"], + hdrs = ["text_model_utils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "text_model_utils_test", + srcs = ["text_model_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. +absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text models with int32 input tensors require TFLite Model " + "Metadata but none was found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + + if (num_input_tensors == kNumInputTensorsForBert) { + return TextModelType::BERT_MODEL; + } + + if (num_input_tensors == kNumInputTensorsForRegex) { + return TextModelType::REGEX_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with int32 input tensors should take exactly $0 " + "or $1 input tensors, but found $2", + kNumInputTensorsForBert, kNumInputTensorsForRegex, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} + +// Determines the ModelType for a model with string input tensors based +// on the number of input tensors. Returns an error if there is an invalid +// number of input tensors.
+absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + if (num_input_tensors == kNumInputTensorsForStringPreprocessor) { + return TextModelType::STRING_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with string input tensors should take exactly " + "$0 tensors, but found $1", + kNumInputTensorsForStringPreprocessor, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} +} // namespace + +absl::StatusOr<TextModelType::ModelType> GetModelType( + const ModelResources& model_resources) { + const tflite::SubGraph& model_graph = + *(*model_resources.GetTfLiteModel()->subgraphs())[0]; + bool all_int32_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; + }); + bool all_string_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; + }); + if (!all_int32_tensors && !all_string_tensors) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "All input tensors should have type int32 or all should have type " + "string", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + if (all_string_tensors) { + return GetStringTensorModelType(model_resources, + model_graph.inputs()->size()); + } + + // Otherwise, all tensors should have type int32 + return GetIntTensorModelType(model_resources, model_graph.inputs()->size()); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h new file mode 100644 index 000000000..da8783d33 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +namespace mediapipe::tasks::text::utils { + +// Determines the ModelType for the model based on its metadata as well +// as its input tensors' type and count. Returns an error if there is no +// compatible model type. +absl::StatusOr<components::processors::proto::TextModelType::ModelType> +GetModelType(const core::ModelResources& model_resources); + +} // namespace mediapipe::tasks::text::utils + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc new file mode 100644 index 000000000..c02f8eca5 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2022 The MediaPipe Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include <memory> +#include <string> + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::utils { + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::proto::ExternalFile; + +constexpr absl::string_view kTestModelResourcesTag = "test_model_resources"; + +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/text/"; +// Classification model with BERT preprocessing. +constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite"; +// Embedding model with BERT preprocessing. +constexpr absl::string_view kMobileBert = + "mobilebert_embedding_with_metadata.tflite"; +// Classification model with regex preprocessing. +constexpr absl::string_view kRegexClassifierPath = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +// Embedding model with regex preprocessing. +constexpr absl::string_view kRegexOneEmbeddingModel = + "regex_one_embedding_with_metadata.tflite"; +// Classification model that takes a string tensor and outputs a bool tensor.
+constexpr absl::string_view kStringToBoolModelPath = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile( + absl::string_view file_name) { + auto model_file = std::make_unique<ExternalFile>(); + model_file->set_file_name(GetFullPath(file_name)); + ASSIGN_OR_RETURN(auto model_resources, + ModelResources::Create(std::string(kTestModelResourcesTag), + std::move(model_file))); + return GetModelType(*model_resources); +} + +} // namespace + +class TextModelUtilsTest : public tflite_shims::testing::Test {}; + +TEST_F(TextModelUtilsTest, BertClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kBertClassifierPath)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, BertEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexClassifierPath)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexOneEmbeddingModel)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, StringInputModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kStringToBoolModelPath)); + ASSERT_EQ(model_type, TextModelType::STRING_MODEL); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c3c0a0261..a86b2cca8 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { if (roi.left >= roi.right || roi.top >= roi.bottom) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect with left < right and top < bottom.", + "Expected RectF with left < right and top < bottom.", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect values to be in [0,1].", + "Expected RectF values to be in [0,1].", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } normalized_rect.set_x_center((roi.left + roi.right) / 2.0); diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index 7e764c1fe..1983272fc 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -35,7 +35,8 @@ struct ImageProcessingOptions { // the full image is used. // // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. - std::optional<components::containers::Rect> region_of_interest = std::nullopt; + std::optional<components::containers::RectF> region_of_interest = + std::nullopt; // The rotation to apply to the image (or cropped region-of-interest), in // degrees clockwise.
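The ImageProcessingOptions struct above, together with the RectF validation in base_vision_task_api.h, is how callers scope a vision task to a sub-region of the input. A minimal usage sketch, modeled on the image_classifier tests later in this change; the model path is a placeholder and the helper name ClassifyRegion is illustrative, not part of the API:

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"

using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
using ::mediapipe::tasks::vision::image_classifier::ImageClassifierOptions;

// Illustrative helper: classify a normalized sub-region of `image`, rotated
// 90° anti-clockwise, mirroring SucceedsWithRegionOfInterestAndRotation below.
absl::Status ClassifyRegion(const mediapipe::Image& image) {
  auto options = std::make_unique<ImageClassifierOptions>();
  // Placeholder path; any image classification model with metadata works.
  options->base_options.model_asset_path = "/path/to/model.tflite";
  ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> classifier,
                   ImageClassifier::Create(std::move(options)));
  // Normalized coordinates, validated as left < right, top < bottom, and all
  // values in [0,1], per the checks in base_vision_task_api.h above.
  RectF roi{/*left=*/0.1, /*top=*/0.1, /*right=*/0.9, /*bottom=*/0.9};
  ImageProcessingOptions image_processing_options{roi,
                                                  /*rotation_degrees=*/-90};
  ASSIGN_OR_RETURN(auto result,
                   classifier->Classify(image, image_processing_options));
  return classifier->Close();
}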
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 277bb170a..088f97c29 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -35,6 +35,8 @@ limitations under the License. namespace mediapipe { namespace api2 { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index fe6f1162b..a1a44c8d1 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -33,6 +33,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index e7fcf6fd9..01f444742 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -57,6 +57,8 @@ namespace { using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: gesture_recognizer::proto::GestureRecognizerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandGestureSubgraphTypeName[] = "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 47d95100b..2d949c410 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -46,6 +46,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index d7e983d81..4db57e85b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -52,6 +52,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index dcefa075f..edbabc018 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package 
mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index bff4e0a9c..df909a6db 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index 57d8a3746..fef22c07c 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index 7df2fed37..ae85509da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index c24548c9b..49958e36b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -50,6 +50,7 @@ namespace hand_detector { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index cbbc0e193..f4e5f8c7d 100644 --- 
a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -53,6 +53,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index a009f2365..bede70da5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index b6df80588..dffdbdd38 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -27,6 +27,8 @@ limitations under the License. namespace mediapipe::api2 { +using ::mediapipe::NormalizedRect; + // HandAssociationCalculator accepts multiple inputs of vectors of // NormalizedRect. The output is a vector of NormalizedRect that contains // rects from the input vectors that don't overlap with each other. When two diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index cb3130854..138164209 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -26,6 +26,8 @@ limitations under the License. namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + class HandAssociationCalculatorTest : public testing::Test { protected: HandAssociationCalculatorTest() { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 564184c64..d875de98f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -41,10 +41,11 @@ limitations under the License. 
namespace mediapipe::api2 { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::DuplicatesFinder; @@ -126,7 +127,7 @@ absl::StatusOr<float> HandBaselineDistance( return distance; } -Rect CalculateBound(const NormalizedLandmarkList& list) { +RectF CalculateBound(const NormalizedLandmarkList& list) { constexpr float kMinInitialValue = std::numeric_limits<float>::max(); constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest(); @@ -144,10 +145,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return Rect{/*left=*/bounding_box_left, - /*top=*/bounding_box_top, - /*right=*/bounding_box_right, - /*bottom=*/bounding_box_bottom}; + return RectF{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect @@ -172,7 +173,7 @@ class HandDuplicatesFinder : public DuplicatesFinder { const int num = multi_landmarks.size(); std::vector<float> baseline_distances; baseline_distances.reserve(num); - std::vector<Rect> bounds; + std::vector<RectF> bounds; bounds.reserve(num); for (const NormalizedLandmarkList& list : multi_landmarks) { ASSIGN_OR_RETURN(const float baseline_distance, diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 2b818b2e5..3bb1ee8d8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -46,6 +46,8 @@ namespace { using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: hand_landmarker::proto::HandLandmarkerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandLandmarkerGraphTypeName[] = "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 2c4133eb1..05ad97efe 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -49,6 +49,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index f275486f5..c28df2c05 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -54,6 +54,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index fa49a4c1f..94d1b1c12 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -50,7 +50,7 @@ namespace { using ::file::Defaults; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::EqualsProto; @@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->running_mode = core::RunningMode::IMAGE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, HandLandmarker::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = hand_landmarker->Detect(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 014830ba2..4ea066aab 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -53,6 +53,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index d1e928ce7..f28907d2f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -50,6 +50,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 51e4e129a..d0edf99c0 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 195f6e5cc..a2d520963 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import
"mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 60f8f7ed4..763e0a320 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -58,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2d0379c66..0adcf842d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -38,6 +38,7 @@ namespace image_classifier { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 1144e9032..7aa2a148c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -52,7 +52,7 @@ namespace { using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::Category; using ::mediapipe::tasks::components::containers::Classifications; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( @@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the chair, with 90° anti-clockwise rotation. - Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, + /*bottom=*/0.3049}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; @@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { ImageClassifier::Create(std::move(options))); // Invalid: left > right. 
- Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect values to be in [0,1]")); + HasSubstr("Expected RectF values to be in [0,1]")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { @@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball.
- Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 76315e230..24b126a35 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index e3198090f..494b075a7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -54,6 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 81ccb5361..95c4ff379 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -34,6 +34,7 @@ namespace image_embedder { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 6098a9a70..dd602bef5 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -41,7 +41,7 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". - Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. 
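All of the Rect→RectF renames in these classifier and embedder tests funnel through the ROI handling in base_vision_task_api.h shown earlier, which converts the validated RectF into the NormalizedRect consumed by the graphs. Only the x_center line appears in this diff; the sketch below fills in the rest by symmetry and should be read as an inferred approximation (the helper name ToNormalizedRect is illustrative), not the verbatim implementation:

#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

// Inferred sketch of the RectF -> NormalizedRect conversion.
mediapipe::NormalizedRect ToNormalizedRect(
    const mediapipe::tasks::components::containers::RectF& roi) {
  mediapipe::NormalizedRect normalized_rect;
  normalized_rect.set_x_center((roi.left + roi.right) / 2.0);  // from the diff
  normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0);  // inferred
  normalized_rect.set_width(roi.right - roi.left);             // inferred
  normalized_rect.set_height(roi.bottom - roi.top);            // inferred
  return normalized_rect;
}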
@@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger_rotated.jpg"))); // Region-of-interest corresponding to burger_crop.jpg. - Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 72b3e7ee3..24ee866f2 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 2124fe6e0..4c9c6e69c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -28,7 +28,6 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -36,6 +35,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -56,17 +56,17 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", 
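The BUILD edits above move segmenter_options out of components/proto and into vision/image_segmenter/proto, so every consumer has to update both its include path and its C++ namespace. The pattern, as applied in the image_segmenter.cc and image_segmenter_graph.cc hunks below:

// Before this change:
//   #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
//   using ::mediapipe::tasks::components::proto::SegmenterOptions;
// After this change:
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;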
diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD similarity index 94% rename from mediapipe/tasks/cc/components/calculators/tensor/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index 6e4322a8f..dcd7fb407 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -25,7 +25,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", ], ) @@ -45,7 +45,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc similarity index 95% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 40585848f..668de0057 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO consolidate TensorsToSegmentationCalculator. #include #include #include @@ -35,14 +34,14 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/util/label_map.pb.h" +// TODO: consolidate TensorToSegmentationCalculator. namespace mediapipe { namespace tasks { - namespace { using ::mediapipe::Image; @@ -51,9 +50,9 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; void StableSoftmax(absl::Span<const float> values, absl::Span<float> activated_values) { @@ -90,7 +89,7 @@ void Sigmoid(absl::Span<const float> values, // the size to resize masks to.
// // Output: -// Segmentation: Segmenation proto. +// Segmentation: Segmentation proto. // // Options: // See tensors_to_segmentation_calculator.proto @@ -132,8 +131,7 @@ class TensorsToSegmentationCalculator : public Node { absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { - options_ = - cc->Options<TensorsToSegmentationCalculatorOptions>(); + options_ = cc->Options<TensorsToSegmentationCalculatorOptions>(); RET_CHECK_NE(options_.segmenter_options().output_type(), SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto similarity index 82% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index c26cf910a..b0fdfdd32 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -15,10 +15,11 @@ limitations under the License. syntax = "proto2"; +// TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; message TensorsToSegmentationCalculatorOptions { @@ -26,7 +27,8 @@ message TensorsToSegmentationCalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; } - optional components.proto.SegmenterOptions segmenter_options = 1; + optional mediapipe.tasks.vision.image_segmenter.proto.SegmenterOptions + segmenter_options = 1; // Identifying information for each classification label. map<int64, LabelMapItem> label_items = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc similarity index 99% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 55e46d72b..54fb9b816 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -33,10 +33,9 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" namespace mediapipe { -namespace api2 { namespace { @@ -374,5 +373,4 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { expected_index, buffer_indices))); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 6dce1b4ea..7130c72e2 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -18,12 +18,12 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" namespace mediapipe { namespace tasks { @@ -44,7 +44,8 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; -using ::mediapipe::tasks::components::proto::SegmenterOptions; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d5eb5af0d..923cf2937 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -49,15 +49,16 @@ namespace image_segmenter { namespace { using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 752a116dd..f9618c1b1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -28,11 +28,11 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -47,7 +47,7 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = segmenter->Segment(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 3b14060f1..9523dd679 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -18,13 +18,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + mediapipe_proto_library( name = "image_segmenter_graph_options_proto", srcs = ["image_segmenter_graph_options.proto"], deps = [ + ":segmenter_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 166e2e8e0..5c7d2ec71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,8 +18,9 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "ImageSegmenterGraphOptionsProto"; @@ -37,5 +38,5 @@ message ImageSegmenterGraphOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output 
options. - optional components.proto.SegmenterOptions segmenter_options = 3; + optional SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto similarity index 92% rename from mediapipe/tasks/cc/components/proto/segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index ca9986707..be2b8a589 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.vision.image_segmenter.proto; -option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "SegmenterOptionsProto"; // Shared options used by image segmentation tasks. diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index c2dd9995d..5269796ae 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -22,9 +22,7 @@ cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", @@ -35,6 +33,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/containers:detection_result", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -63,6 +62,7 @@ cc_library( "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/calculators/util:detections_deduplicate_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index dd19237ff..2477f8a44 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,8 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +132,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +148,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +165,11 @@ absl::StatusOr> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +190,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result typo. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task. 
struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function<void(absl::StatusOr<std::vector<Detection>>, - const Image&, int64)> + std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)> result_callback = nullptr; }; @@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. // TODO: Describe the output bounding boxes for gpu input // images after enabling the gpu support in MediaPipe Tasks. - absl::StatusOr<std::vector<Detection>> Detect( + absl::StatusOr<ObjectDetectorResult> Detect( mediapipe::Image image, std::optional<core::ImageProcessingOptions> image_processing_options = std::nullopt); @@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // unrotated input frame of reference coordinate system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. - absl::StatusOr<std::vector<Detection>> DetectForVideo( + absl::StatusOr<ObjectDetectorResult> DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional<core::ImageProcessingOptions> image_processing_options = std::nullopt); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index f5dc7e061..e5af7544d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -52,6 +52,7 @@ namespace vision { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; @@ -532,8 +533,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has a single subgraph. auto& model = *model_resources.GetTfLiteModel(); - if (model.subgraphs()->size() != 1 || - (*model.subgraphs())[0]->outputs()->size() != 4) { + if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected a model with a single subgraph, found %d.", @@ -663,11 +663,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { detection_transformation.Out(kPixelDetectionsTag) >> detection_label_id_to_text.In(""); + // Deduplicate Detections with same bounding box coordinates. + auto& detections_deduplicate = + graph.AddNode("DetectionsDeduplicateCalculator"); + detection_label_id_to_text.Out("") >> detections_deduplicate.In(""); + // Outputs the labeled detections and the processed image as the subgraph // output streams. return {{ /* detections= */ - detection_label_id_to_text[Output<std::vector<Detection>>("")], + detections_deduplicate[Output<std::vector<Detection>>("")], /* image= */ preprocessing[Output<Image>(kImageTag)], }}; }
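[Editor's note] Taken together, the object_detector.cc/.h hunks above switch every detection API from std::vector<mediapipe::Detection> to the new ObjectDetectorResult alias. A minimal sketch of calling code under the new surface; the model path is a placeholder, and the field accesses follow the components::containers::DetectionResult struct exercised by the rewritten tests below:

    #include <iostream>
    #include <memory>
    #include <utility>

    #include "mediapipe/framework/formats/image.h"
    #include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

    using ::mediapipe::tasks::vision::ObjectDetector;
    using ::mediapipe::tasks::vision::ObjectDetectorOptions;
    using ::mediapipe::tasks::vision::ObjectDetectorResult;

    absl::Status DetectOnce(mediapipe::Image image) {
      auto options = std::make_unique<ObjectDetectorOptions>();
      options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
      options->max_results = 5;
      auto detector = ObjectDetector::Create(std::move(options));
      if (!detector.ok()) return detector.status();
      // Detect() now returns absl::StatusOr<ObjectDetectorResult> rather than
      // absl::StatusOr<std::vector<mediapipe::Detection>>.
      absl::StatusOr<ObjectDetectorResult> result = (*detector)->Detect(image);
      if (!result.ok()) return result.status();
      for (const auto& detection : result->detections) {
        std::cout << detection.categories[0].category_name << ": "
                  << detection.categories[0].score << "\n";
      }
      return (*detector)->Close();
    }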
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -65,10 +66,14 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; +using ::mediapipe::tasks::components::containers::Detection; +using ::mediapipe::tasks::components::containers::DetectionResult; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; +using DetectionProto = mediapipe::Detection; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kMobileSsdWithMetadata[] = @@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] = // Checks that the two provided `Detection` proto vectors are equal, with a // tolerancy on floating-point scores to account for numerical instabilities. // If the proto definition changes, please also change this function. -void ExpectApproximatelyEqual(const std::vector& actual, - const std::vector& expected) { +void ExpectApproximatelyEqual(const ObjectDetectorResult& actual, + const ObjectDetectorResult& expected) { const float kPrecision = 1e-6; - EXPECT_EQ(actual.size(), expected.size()); - for (int i = 0; i < actual.size(); ++i) { - const Detection& a = actual[i]; - const Detection& b = expected[i]; - EXPECT_THAT(a.location_data().bounding_box(), - EqualsProto(b.location_data().bounding_box())); - EXPECT_EQ(a.label_size(), 1); - EXPECT_EQ(b.label_size(), 1); - EXPECT_EQ(a.label(0), b.label(0)); - EXPECT_EQ(a.score_size(), 1); - EXPECT_EQ(b.score_size(), 1); - EXPECT_NEAR(a.score(0), b.score(0), kPrecision); + EXPECT_EQ(actual.detections.size(), expected.detections.size()); + for (int i = 0; i < actual.detections.size(); ++i) { + const Detection& a = actual.detections[i]; + const Detection& b = expected.detections[i]; + EXPECT_EQ(a.bounding_box, b.bounding_box); + EXPECT_EQ(a.categories.size(), 1); + EXPECT_EQ(b.categories.size(), 1); + EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name); + EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision); } } -std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { - return {ParseTextProtoOrDie(R"pb( +std::vector +GenerateMobileSsdNoImageResizingFullExpectedResults() { + return {ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6328125 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.59765625 location_data { format: BOUNDING_BOX bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "dog" score: 0.48828125 location_data { @@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); 
options->running_mode = running_mode; options->result_callback = - [](absl::StatusOr<std::vector<Detection>> detections, - const Image& image, int64 timestamp_ms) {}; + [](absl::StatusOr<ObjectDetectorResult> detections, const Image& image, + int64 timestamp_ms) {}; absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.69921875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.64453125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.51171875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.48828125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.69921875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.64453125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.51171875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.48828125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + })pb")})); } TEST_F(ImageModeTest, SucceedsEfficientDetModel) { @@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.7578125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.72265625 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.6289063 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } - })pb"), - ParseTextProtoOrDie<Detection>(R"pb( - label: "cat" - score: 0.5859375 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.7578125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.72265625 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.6289063 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } + })pb"), +
ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.5859375 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } + })pb")})); } TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { @@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, GenerateMobileSsdNoImageResizingFullExpectedResults()); + results, ConvertToDetectionResult( + GenerateMobileSsdNoImageResizingFullExpectedResults())); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6531269142 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, - {full_expected_results[0], full_expected_results[1], - full_expected_results[2]}); + + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1], + full_expected_results[2]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, 
ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie<Detection>(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb( label: "cat" score: 0.7109375 location_data { format: BOUNDING_BOX bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, ObjectDetector::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); @@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->DetectForVideo(image, i)); - std::vector<Detection> full_expected_results = + std::vector<DetectionProto> full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } MP_ASSERT_OK(object_detector->Close()); } @@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr<std::vector<Detection>> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, ObjectDetector::Create(std::move(options))); @@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->running_mode = core::RunningMode::LIVE_STREAM; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); - options->result_callback = - [](absl::StatusOr<std::vector<Detection>> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, ObjectDetector::Create(std::move(options))); MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); @@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) { auto options = std::make_unique<ObjectDetectorOptions>(); options->max_results = 2; options->running_mode = core::RunningMode::LIVE_STREAM; - std::vector<std::vector<Detection>> detection_results; + std::vector<ObjectDetectorResult> detection_results; std::vector<std::pair<int, int>> image_sizes; std::vector<int64> timestamps; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, &timestamps]( - absl::StatusOr<std::vector<Detection>> detections, const Image& image, + absl::StatusOr<ObjectDetectorResult> detections, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(detections.status()); detection_results.push_back(std::move(detections).value()); @@ -719,11 +733,12 @@
TEST_F(LiveStreamModeTest, Succeeds) { // number of iterations. ASSERT_LE(detection_results.size(), iterations); ASSERT_GT(detection_results.size(), 0); - std::vector<Detection> full_expected_results = + std::vector<DetectionProto> full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); for (const auto& detection_result : detection_results) { ExpectApproximatelyEqual( - detection_result, {full_expected_results[0], full_expected_results[1]}); + detection_result, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1]})); } for (const auto& image_size : image_sizes) { EXPECT_EQ(image_size.first, image.width()); diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index cba58ace8..3f6932f8f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index 2ce9e2454..fe4e63824 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -22,13 +22,13 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; -float CalculateArea(const Rect& rect) { +float CalculateArea(const RectF& rect) { return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Rect& a, const Rect& b) { +float CalculateIntersectionArea(const RectF& a, const RectF& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Rect& a, const Rect& b) { +float CalculateIOU(const RectF& a, const RectF& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0;
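[Editor's note] The same Rect-to-RectF rename appears in the region-of-interest tests above and in the landmark utilities here. A minimal sketch of building ImageProcessingOptions with the renamed container, mirroring the FailsWithRegionOfInterest tests (assuming RectF keeps the left/top/right/bottom layout shown there):

    #include "mediapipe/tasks/cc/components/containers/rect.h"
    #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

    using ::mediapipe::tasks::components::containers::RectF;
    using ::mediapipe::tasks::vision::core::ImageProcessingOptions;

    // RectF holds the normalized, floating-point coordinates previously
    // carried by Rect; values are in [0, 1] relative to the image size.
    ImageProcessingOptions MakeRoiOptions() {
      RectF roi{/*left=*/0.1f, /*top=*/0.0f, /*right=*/0.9f, /*bottom=*/1.0f};
      return ImageProcessingOptions{roi, /*rotation_degrees=*/0};
    }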
diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 73114d2ef..4d1fac62f 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -27,15 +27,15 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. -float CalculateIOU(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIOU(const components::containers::RectF& a, + const components::containers::RectF& b); // Calculates area for face bound -float CalculateArea(const components::containers::Rect& rect); +float CalculateArea(const components::containers::RectF& rect); // Calculates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIntersectionArea(const components::containers::RectF& a, + const components::containers::RectF& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index c6def1685..e8ce47818 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -16,22 +16,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -objc_library( - name = "MPPExternalFile", - srcs = ["sources/MPPExternalFile.m"], - hdrs = ["sources/MPPExternalFile.h"], -) - -objc_library( - name = "MPPBaseOptions", - srcs = ["sources/MPPBaseOptions.m"], - hdrs = ["sources/MPPBaseOptions.h"], - deps = [ - ":MPPExternalFile", - - ], -) - objc_library( name = "MPPTaskOptions", srcs = ["sources/MPPTaskOptions.m"], @@ -85,3 +69,8 @@ objc_library( hdrs = ["sources/MPPTaskResult.h"], ) +objc_library( + name = "MPPBaseOptions", + srcs = ["sources/MPPBaseOptions.m"], + hdrs = ["sources/MPPBaseOptions.h"], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 686e50add..9c6595cfc 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -1,16 +1,18 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import <Foundation/Foundation.h> -#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" NS_ASSUME_NONNULL_BEGIN @@ -18,30 +20,25 @@ NS_ASSUME_NONNULL_BEGIN * MediaPipe Tasks delegate. */ typedef NS_ENUM(NSUInteger, MPPDelegate) { - /** CPU. */ MPPDelegateCPU, - + /** GPU.
*/ MPPDelegateGPU } NS_SWIFT_NAME(Delegate); /** * Holds the base options that are used for the creation of any type of task. It has fields with - * important information acceleration configuration, tflite model source etc. + * important information such as the acceleration configuration, TFLite model source, etc. */ NS_SWIFT_NAME(BaseOptions) @interface MPPBaseOptions : NSObject <NSCopying> -/** - * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model - * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated - * files might result in errors. - */ -@property(nonatomic, copy) MPPExternalFile *modelAssetFile; +/** The path to the model asset to open and mmap in memory. */ +@property(nonatomic, copy) NSString *modelAssetPath; /** - * device delegate to run the MediaPipe pipeline. If the delegate is not set, the default + * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default * delegate CPU is used. */ @property(nonatomic) MPPDelegate delegate; @@ -49,3 +46,4 @@ NS_SWIFT_NAME(BaseOptions) @end NS_ASSUME_NONNULL_END + diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m index 4c25b80e8..b2b027da7 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License.
- ==============================================================================*/ #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" @implementation MPPBaseOptions @@ -19,17 +19,17 @@ - (instancetype)init { self = [super init]; if (self) { - self.modelAssetFile = [[MPPExternalFile alloc] init]; + self.modelAssetPath = [[NSString alloc] init]; } return self; } - (id)copyWithZone:(NSZone *)zone { MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; - - baseOptions.modelAssetFile = self.modelAssetFile; + + baseOptions.modelAssetPath = self.modelAssetPath; baseOptions.delegate = self.delegate; - + return baseOptions; } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2d29ccf23..e5d472e8a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index ad17d5552..4d302b950 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 1f99f1612..b4d453935 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD index b2d27bfa7..6c724106f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 01b1f653a..31f885267 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index d91c03cc2..727d020a6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -21,7 +21,6 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", - "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -43,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", @@ -285,6 +285,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 5b10e9aab..31cd2c89a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) # The native library of all MediaPipe text tasks. 
cc_binary( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 6161fe032..f469aed0c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index 2e5815ff0..6dda7a53c 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -29,11 +29,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -51,11 +51,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index d82b6fe27..cc87d6221 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors 
import classifier_options as classifier_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options_module.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -62,16 +62,31 @@ class AudioClassifierOptions: mode for running classification on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the classification results asynchronously. - classifier_options: Options for configuring the classifier behavior, such as - score threshold, number of results, etc. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -79,7 +94,12 @@ class AudioClassifierOptions: """Generates an AudioClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _AudioClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index 629e21882..4c37783e9 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult _AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -63,16 +63,22 @@ class AudioEmbedderOptions: stream mode for running embedding extraction on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the embedding results asynchronously. - embedder_options: Options for configuring the embedder behavior, such as - l2_normalize and quantize. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. 
In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -80,7 +86,8 @@ class AudioEmbedderOptions: """Generates an AudioEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _AudioEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 3cb9cb8e8..5b4203d7b 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9d275e167..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index eef368db0..695f6df91 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -28,12 +28,3 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index adcb38757..0eb73abe0 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -15,12 +15,9 @@ """MediaPipe Tasks Components Processors API.""" import mediapipe.tasks.python.components.processors.classifier_options -import mediapipe.tasks.python.components.processors.embedder_options ClassifierOptions = classifier_options.ClassifierOptions -EmbedderOptions = embedder_options.EmbedderOptions # Remove unnecessary modules to avoid duplication in API docs. del classifier_options -del embedder_options del mediapipe diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py deleted file mode 100644 index c86a91105..000000000 --- a/mediapipe/tasks/python/components/processors/embedder_options.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Embedder options data class.""" - -import dataclasses -from typing import Any, Optional - -from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions - - -@dataclasses.dataclass -class EmbedderOptions: - """Shared options used by all embedding extraction tasks. - - Attributes: - l2_normalize: Whether to normalize the returned feature vector with L2 norm. - Use this option only if the model does not already contain a native - L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and - L2 norm is thus achieved through TF Lite inference. - quantize: Whether the returned embedding should be quantized to bytes via - scalar quantization. Embeddings are implicitly assumed to be unit-norm and - therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use - the l2_normalize option if this is not the case. 
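The quantize semantics spelled out in the embedder docstrings above are easy to check numerically: a unit-norm embedding has every component in [-1.0, 1.0], so each one can be scalar-quantized to a signed byte independently. A rough illustration of the idea, not the task's actual quantizer:

import numpy as np

v = np.array([0.6, -0.8], dtype=np.float32)  # unit-norm: 0.6**2 + 0.8**2 == 1.0
q = np.clip(np.round(v * 127), -128, 127).astype(np.int8)
print(q)          # [  76 -102]
print(q / 127.0)  # ~[ 0.598 -0.803], close to the original components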
- """ - - l2_normalize: Optional[bool] = None - quantize: Optional[bool] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbedderOptionsProto: - """Generates a EmbedderOptions protobuf object.""" - return _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': - """Creates a `EmbedderOptions` object from the given protobuf object.""" - return EmbedderOptions( - l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, EmbedderOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index b64d04c72..1a18531c6 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -16,15 +16,12 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], - deps = [ - "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", - ], + deps = ["//mediapipe/tasks/python/components/containers:embedding_result"], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 486c02ece..ff8979458 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,10 +16,8 @@ import numpy as np from mediapipe.tasks.python.components.containers import embedding_result -from mediapipe.tasks.python.components.processors import embedder_options _Embedding = embedding_result.Embedding -_EmbedderOptions = embedder_options.EmbedderOptions def _compute_cosine_similarity(u, v): diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index fc0018ab1..f14d59b99 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -23,15 +23,15 @@ py_library( srcs = [ "optional_dependencies.py", ], - deps = [ - "@org_tensorflow//tensorflow/tools/docs:doc_controls", - ], ) py_library( name = "base_options", srcs = ["base_options.py"], - visibility = ["//mediapipe/tasks:users"], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:users", + ], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/core/optional_dependencies.py b/mediapipe/tasks/python/core/optional_dependencies.py index d4f6a6abc..b1a0ed538 100644 --- a/mediapipe/tasks/python/core/optional_dependencies.py +++ b/mediapipe/tasks/python/core/optional_dependencies.py @@ -13,6 +13,13 @@ # limitations under the License. """MediaPipe Tasks' common but optional dependencies.""" -doc_controls = lambda: None -no_op = lambda x: x -setattr(doc_controls, 'do_not_generate_docs', no_op) +# TensorFlow isn't a dependency of the mediapipe pip package. It's only +# required in the API docgen pipeline so we'll ignore it if tensorflow is not +# installed. +try: + from tensorflow.tools.docs import doc_controls +except ModuleNotFoundError: + # Replace the real doc_controls.do_not_generate_docs with a no-op + doc_controls = lambda: None + no_op = lambda x: x + setattr(doc_controls, 'do_not_generate_docs', no_op) diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 9278cea55..43f1d417c 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -30,7 +30,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], @@ -48,7 +47,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 0d067e587..75146547c 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData
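The optional_dependencies.py rewrite above means doc_controls.do_not_generate_docs always exists: TensorFlow's real doc-control decorator when TensorFlow is installed, and a harmless no-op otherwise. A minimal sketch of how the task modules consume it (SomeOptions is a hypothetical stand-in for a task options class):

from mediapipe.tasks.python.core.optional_dependencies import doc_controls

class SomeOptions:  # hypothetical stand-in for a task options dataclass
  @doc_controls.do_not_generate_docs
  def to_pb2(self):
    """Hidden from generated API docs; the decorator is a no-op without TF."""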
_BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' @@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - score_threshold=0.9))) as classifier: + score_threshold=0.9)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['Speech']))) as classifier: + category_allowlist=['Speech'])) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase): r'exclusive options.'): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar'])) + category_allowlist=['foo'], + category_denylist=['bar']) with _AudioClassifier.create_from_options(options) as unused_classifier: pass @@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase): _AudioClassifierOptions( base_options=_BaseOptions( model_asset_path=self.two_heads_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -364,7 +358,7 @@ class AudioClassifierTest(parameterized.TestCase): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - classifier_options=_ClassifierOptions(max_results=1), + max_results=1, result_callback=save_result) classifier = _AudioClassifier.create_from_options(options) audio_data_list = self._read_wav_file_as_stream(audio_file) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 2e38ea2ee..f280235d7 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,7 +26,6 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module -from 
mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -35,7 +34,6 @@ _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options.EmbedderOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' @@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _AudioEmbedderOptions( - base_options=base_options, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize)) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _AudioEmbedder.create_from_options(options) as embedder: embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) @@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase): options = _AudioEmbedderOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize), + l2_normalize=l2_normalize, + quantize=quantize, result_callback=save_result) with _AudioEmbedder.create_from_options(options) as embedder: diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 38e56bdb2..0e2b06012 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -28,7 +28,6 @@ py_test( deps = [ "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_classifier", @@ -44,7 +43,6 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_embedder", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 8678d2194..8df7dce86 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -21,14 +21,12 @@ from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier diff --git 
a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index c9090026c..1346ba373 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -21,13 +21,11 @@ from absl.testing import parameterized import numpy as np from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_embedder _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions @@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _TextEmbedder.create_from_options(options) # Extracts both embeddings. @@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _TextEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. 
positive_text0 = "it's a charming and often affecting journey" diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 066107421..48ecc30b3 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,7 +49,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", @@ -69,7 +68,6 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_embedder", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 77f16278f..cbeaf36bd 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier @@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _Classifications = classification_result_module.Classifications _Image = image.Image @@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. @@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
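The image classifier test edits above and below all repeat one mechanical migration, summarized here as a hedged before/after sketch (model path hypothetical):

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_classifier

base_options = base_options_module.BaseOptions(
    model_asset_path='classifier.tflite')  # hypothetical model path

# Before this change, the tuning parameter needed a nested object:
#   classifier_options=_ClassifierOptions(max_results=3)
# After it, the same knob is passed directly:
options = image_classifier.ImageClassifierOptions(
    base_options=base_options, max_results=3)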
image_result = classifier.classify(self.test_image) @@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase): def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) - custom_classifier_options = _ClassifierOptions(max_results=1) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase): _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): - custom_classifier_options = _ClassifierOptions( - category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=_ALLOW_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): - custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_denylist=_DENY_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): - custom_classifier_options = _ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=['foo'], + category_denylist=['bar']) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): - custom_classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=1) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=4) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( @@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase): _generate_burger_results().to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. 
test_image = _Image.create_from_file( @@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, result_callback=mock.MagicMock()) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) @@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions( - max_results=4, score_threshold=threshold) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, + score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): @@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4bb96bad6..11c0cf002 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -24,7 +24,6 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_embedder @@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _ImageEmbedder.create_from_options(options) image_processing_options = None @@ -186,10 +182,8 @@ class 
ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _ImageEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index 10b4b8a6e..9d5d23261 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -28,9 +28,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -47,9 +47,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD index 072a0c7d8..e76bd4b6d 100644 --- a/mediapipe/tasks/python/text/core/BUILD +++ b/mediapipe/tasks/python/text/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
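The image embedder tests above exercise the same flattening for l2_normalize and quantize. A condensed usage sketch, assuming local model and image files (hypothetical paths) and a cosine_similarity classmethod on the task, an assumption consistent with the cosine_similarity utility dependency retained in the BUILD changes above:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_embedder

options = image_embedder.ImageEmbedderOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='embedder.tflite'),  # hypothetical model path
    l2_normalize=True,
    quantize=True)
with image_embedder.ImageEmbedder.create_from_options(options) as embedder:
  result0 = embedder.embed(image_module.Image.create_from_file('a.jpg'))  # hypothetical
  result1 = embedder.embed(image_module.Image.create_from_file('b.jpg'))  # hypothetical
  similarity = image_embedder.ImageEmbedder.cosine_similarity(
      result0.embeddings[0], result1.embeddings[0])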
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 9711e8b3a..fdb20f0ef 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,14 +14,14 @@ """MediaPipe text classifier task.""" import dataclasses -from typing import Optional +from typing import Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _TaskInfo = task_info_module.TaskInfo _CLASSIFICATIONS_STREAM_NAME = 'classifications_out' @@ -46,17 +46,38 @@ class TextClassifierOptions: Attributes: base_options: Base options for the text classifier task. - classifier_options: Options for the text classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. 
""" base_options: _BaseOptions - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: """Generates an TextClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _TextClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index a9e560ac9..be899636d 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -19,9 +19,9 @@ from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _TaskInfo = task_info_module.TaskInfo _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' @@ -47,17 +47,25 @@ class TextEmbedderOptions: Attributes: base_options: Base options for the text embedder task. - embedder_options: Options for the text embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. 
""" base_options: _BaseOptions - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: """Generates an TextEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _TextEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e94507eed..eda8e290d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -47,10 +47,10 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -69,8 +69,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -89,9 +89,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", @@ -131,6 +131,10 @@ py_library( srcs = [ "hand_landmarker.py", ], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2", diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index e2b2b3dec..18df690a0 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test 
compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 6cbce7860..b60d18e31 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,17 +14,17 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping, Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -63,15 +63,31 @@ class ImageClassifierOptions: objects on single image inputs. 2) The video mode for classifying objects on the decoded frames of a video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - classifier_options: Options for the image classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
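As the flattened ImageClassifierOptions docstring above notes, result_callback pairs with the live stream running mode. A minimal sketch combining the new flat fields with a callback (model path hypothetical; imports mirror the test files):

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import vision_task_running_mode

def on_result(result, output_image, timestamp_ms):
  # Invoked asynchronously with each frame's classifications.
  print(timestamp_ms, result)

options = image_classifier.ImageClassifierOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='classifier.tflite'),  # hypothetical model path
    running_mode=vision_task_running_mode.VisionTaskRunningMode.LIVE_STREAM,
    max_results=4,
    result_callback=on_result)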
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None @@ -80,7 +96,12 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index a58dca3ae..0bae21bda 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -62,15 +62,22 @@ class ImageEmbedderOptions: image on single image inputs. 2) The video mode for embedding image on the decoded frames of a video. 3) The live stream mode for embedding image on a live stream of input data, such as from camera. - embedder_options: Options for the image embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. 
Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None @@ -79,7 +86,8 @@ class ImageEmbedderOptions: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _ImageEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 62fc8bb7c..22a37cb3e 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 081e63c2c..a0131c056 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -18,7 +18,10 @@ load( ) package( - default_visibility = ["//mediapipe/framework:mediapipe_internal"], + default_visibility = [ + "//mediapipe/calculators/tensor:__subpackages__", + "//mediapipe/tasks:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 20e717433..bc9e84147 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -44,7 +44,6 @@ rollup_bundle( ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -88,7 +87,6 @@ rollup_bundle( ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -132,7 +130,6 @@ rollup_bundle( ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index d08602521..9d26f1118 100644 --- 
a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6f785dd0d..24ef31feb 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,6 +2,7 @@ # # This task takes audio data and outputs the classification result. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", srcs = ["audio_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +37,7 @@ mediapipe_ts_declaration( "audio_classifier_options.d.ts", "audio_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", @@ -42,3 +45,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "audio_classifier_test_lib", + testonly = True, + srcs = [ + "audio_classifier_test.ts", + ], + deps = [ + ":audio_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_classifier_test", + deps = [":audio_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 265ba2b33..7bfca680a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -94,6 +94,7 @@ export class AudioClassifier extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts new file mode 100644 index 000000000..d5c0a9429 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -0,0 +1,208 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioClassifier} from './audio_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class AudioClassifierFake extends AudioClassifier implements + MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + private resultProtoVector: ClassificationResult[] = []; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_classifications'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'addDoubleToStream') + .and.callFake((sampleRate, streamName, timestamp) => { + if (streamName === 'sample_rate') { + this.lastSampleRate = sampleRate; + } + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape') + .and.callFake( + (audioData, numChannels, numSamples, streamName, timestamp) => { + expect(numChannels).toBe(1); + }); + spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { + if (!this.protoVectorListener) return; + this.protoVectorListener(this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary())); + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } + + /** Sets the Protobuf that will be sent to the API.
*/ + setResults(results: ClassificationResult[]): void { + this.resultProtoVector = results; + } +} + +describe('AudioClassifier', () => { + let audioClassifier: AudioClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioClassifier = new AudioClassifierFake(); + await audioClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(audioClassifier); + verifyListenersRegistered(audioClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await audioClassifier.setOptions({maxResults: 1}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(audioClassifier); + + await audioClassifier.setOptions({maxResults: 5}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(audioClassifier); + }); + + it('merges options', async () => { + await audioClassifier.setOptions({maxResults: 1}); + await audioClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(audioClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([]), 44100); + expect(audioClassifier.lastSampleRate).toEqual(44100); + }); + + it('transforms results', async () => { + const resultProtoVector: ClassificationResult[] = []; + + let classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(0); + let classifications = new Classifications(); + classifications.setHeadIndex(1); + classifications.setHeadName('headName'); + let classificationList = new ClassificationList(); + let classification = new Classification(); + classification.setIndex(1); + classification.setScore(0.2); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); + classifications.setClassificationList(classificationList); + classificationResult.addClassifications(classifications); + resultProtoVector.push(classificationResult); + + classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + classifications = new Classifications(); + classificationList = new ClassificationList(); + classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classificationList.addClassification(classification); + classifications.setClassificationList(classificationList); + classificationResult.addClassifications(classifications); + resultProtoVector.push(classificationResult); + + // Invoke the audio classifier + audioClassifier.setResults(resultProtoVector); + const results = audioClassifier.classify(new Float32Array([])); + expect(results.length).toEqual(2); + expect(results[0]).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 0 + }); + expect(results[1]).toEqual({ + classifications: [{ + categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + timestampMs: 1 + }); + }); + + it('clears results between invocations', async () => { + const classificationResult = new ClassificationResult(); + const classifications = new Classifications(); + const classificationList = new ClassificationList(); + const classification = new Classification(); + classificationList.addClassification(classification); + classifications.setClassificationList(classificationList); + classificationResult.addClassifications(classifications); + + audioClassifier.setResults([classificationResult]); + + // Invoke the audio classifier twice + const classifications1 = audioClassifier.classify(new Float32Array([])); + const classifications2 = audioClassifier.classify(new Float32Array([])); + + // Verify that classifications2 is not a concatenation of all previously + // returned classifications. + expect(classifications1).toEqual(classifications2); + }); +}); diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0555bb639..0817776c5 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -3,6 +3,7 @@ # This task takes audio input and performs embedding. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_embedder", srcs = ["audio_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,9 +37,30 @@ mediapipe_ts_declaration( "audio_embedder_options.d.ts", "audio_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", ], ) + +mediapipe_ts_library( + name = "audio_embedder_test_lib", + testonly = True, + srcs = [ + "audio_embedder_test.ts", + ], + deps = [ + ":audio_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_embedder_test", + deps = [":audio_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 445dd5172..246cba883 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -96,6 +96,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts new file mode 100644 index 000000000..2f605ff98 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -0,0 +1,185 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder'; + + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + this.attachListenerSpies[1] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_embeddings_out'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => { + this.lastSampleRate = sampleRate; + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape'); + } +} + +describe('AudioEmbedder', () => { + let audioEmbedder: AudioEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioEmbedder = new AudioEmbedderFake(); + await audioEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', () => { + verifyGraph(audioEmbedder); + verifyListenersRegistered(audioEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await audioEmbedder.setOptions({quantize: true}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(audioEmbedder); + + await audioEmbedder.setOptions({quantize: undefined}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(audioEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = 
Buffer.from(newModel).toString('base64'); + await audioEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + audioEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await audioEmbedder.setOptions({quantize: true}); + await audioEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + audioEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([]), 44100); + expect(audioEmbedder.lastSampleRate).toEqual(44100); + }); + + describe('transforms results', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResultProto(); + resultProto.addEmbeddings(embedding); + + function validateEmbeddingResult( + expectedEmbeddingResult: AudioEmbedderResult[]) { + expect(expectedEmbeddingResult.length).toEqual(1); + + const [embeddingResult] = expectedEmbeddingResult; + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + } + + it('from embeddings stream', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([])); + validateEmbeddingResult(embeddingResults); + }); + + it('from timestamped embeddings stream', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42); + validateEmbeddingResult(embeddingResults); + }); + }); +}); diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 9ab6c7bee..cea689838 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -7,8 +7,5 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], - deps = [ - "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - ], + deps = ["//mediapipe/tasks/web/core:task_runner"], ) diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index 00cfe0253..24d78378d 100644 ---
a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -15,10 +15,9 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner<T> extends TaskRunner<TaskRunnerOptions> { +export abstract class AudioTaskRunner<T> extends TaskRunner { private defaultSampleRate = 48000; /** diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 86e743928..148a08238 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,6 +14,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_options_test_lib", + testonly = True, + srcs = ["classifier_options.test.ts"], + deps = [ + ":classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +jasmine_node_test( + name = "classifier_options_test", + deps = [":classifier_options_test_lib"], +) + mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], @@ -22,6 +39,22 @@ ], ) +mediapipe_ts_library( + name = "classifier_result_test_lib", + testonly = True, + srcs = ["classifier_result.test.ts"], + deps = [ + ":classifier_result", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + ], +) + +jasmine_node_test( + name = "classifier_result_test", + deps = [":classifier_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_result", srcs = ["embedder_result.ts"], @@ -31,6 +64,21 @@ ], ) +mediapipe_ts_library( + name = "embedder_result_test_lib", + testonly = True, + srcs = ["embedder_result.test.ts"], + deps = [ + ":embedder_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + ], +) + +jasmine_node_test( + name = "embedder_result_test", + deps = [":embedder_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_options", srcs = ["embedder_options.ts"], @@ -40,6 +88,22 @@ ], ) +mediapipe_ts_library( + name = "embedder_options_test_lib", + testonly = True, + srcs = ["embedder_options.test.ts"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +jasmine_node_test( + name = "embedder_options_test", + deps = [":embedder_options_test_lib"], +) + mediapipe_ts_library( name = "base_options", srcs = [ @@ -53,3 +117,15 @@ "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "base_options_test_lib", + testonly = True, + srcs = ["base_options.test.ts"], + deps = [":base_options"], +) + +jasmine_node_test( + name = "base_options_test", + deps = [":base_options_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts new file mode 100644 index
000000000..46c2277e9 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -0,0 +1,127 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +// Placeholder for internal dependency on trusted resource URL builder + +import {convertBaseOptionsToProto} from './base_options'; + +describe('convertBaseOptionsToProto()', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + }); + + it('verifies that at least one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({})) + .toBeRejectedWithError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({ + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + })) + .toBeRejectedWithError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('downloads model', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetPath: `foo`, + }); + + expect(fetchSpy).toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable CPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'cpu', + }); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + expect(baseOptionsProto.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + 
let baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + // Clear backend + baseOptionsProto = + await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_options.test.ts b/mediapipe/tasks/web/components/processors/classifier_options.test.ts new file mode 100644 index 000000000..928bda426 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_options.test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +import {convertClassifierOptionsToProto} from './classifier_options'; + +interface TestCase { + optionName: keyof ClassifierOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertClassifierOptionsToProto()', () => { + function verifyOption( + actualClassifierOptions: ClassifierOptionsProto, + expectedClassifierOptions: Record<string, unknown> = {}): void { + expect(actualClassifierOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedClassifierOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + + classifierOptionsProto = + convertClassifierOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + classifierOptionsProto, + {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto,
{'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {maxResults: 2}, classifierOptionsProto); + verifyOption(classifierOptionsProto, {'maxResults': 2}); + }); + + it('merges options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {displayNamesLocale: 'en'}, classifierOptionsProto); + verifyOption( + classifierOptionsProto, {'maxResults': 1, 'displayNamesLocale': 'en'}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts b/mediapipe/tasks/web/components/processors/classifier_result.test.ts new file mode 100644 index 000000000..4b93d0a76 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts @@ -0,0 +1,80 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; + +import {convertFromClassificationResultProto} from './classifier_result'; + +// The OSS JS API does not support the builder pattern. 
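The two tests that follow pin down how `convertFromClassificationResultProto` maps proto fields onto the JS result shape, including the defaults substituted when a proto field is unset. As a rough sketch of that mapping (using simplified plain objects instead of the real jspb message classes, so the names here are illustrative only):

```ts
// Illustrative only: the real converter consumes jspb message objects.
interface CategoryLike {
  index: number;
  score: number;
  displayName: string;
  categoryName: string;
}

function toCategory(proto: {
  index?: number,
  score?: number,
  displayName?: string,
  label?: string,
}): CategoryLike {
  // Unset proto fields collapse to 0 / '' defaults, which is exactly what
  // the 'transforms default values' test below asserts.
  return {
    index: proto.index ?? 0,
    score: proto.score ?? 0,
    displayName: proto.displayName ?? '',
    categoryName: proto.label ?? '',  // proto `label` surfaces as `categoryName`
  };
}
```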
+// tslint:disable:jspb-use-builder-pattern + +describe('convertFromClassificationResultProto()', () => { + it('transforms custom values', () => { + const classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + const classifications = new Classifications(); + classifications.setHeadIndex(1); + classifications.setHeadName('headName'); + const classificationList = new ClassificationList(); + const classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); + classifications.setClassificationList(classificationList); + classificationResult.addClassifications(classifications); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 2, + score: 0.3, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 1 + }); + }); + + it('transforms default values', () => { + const classificationResult = new ClassificationResult(); + const classifications = new Classifications(); + const classificationList = new ClassificationList(); + const classification = new Classification(); + classificationList.addClassification(classification); + classifications.setClassificationList(classificationList); + classificationResult.addClassifications(classifications); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{index: 0, score: 0, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + }); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_options.test.ts b/mediapipe/tasks/web/components/processors/embedder_options.test.ts new file mode 100644 index 000000000..b879a6b29 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.test.ts @@ -0,0 +1,93 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +import 'jasmine'; + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +import {convertEmbedderOptionsToProto} from './embedder_options'; + +interface TestCase { + optionName: keyof EmbedderOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertEmbedderOptionsToProto()', () => { + function verifyOption( + actualEmbedderOptions: EmbedderOptionsProto, + expectedEmbedderOptions: Record<string, unknown> = {}): void { + expect(actualEmbedderOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedEmbedderOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'l2Normalize', + protoName: 'l2Normalize', + customValue: true, + defaultValue: undefined + }, + { + optionName: 'quantize', + protoName: 'quantize', + customValue: true, + defaultValue: undefined + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + + embedderOptionsProto = + convertEmbedderOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let embedderOptionsProto = + convertEmbedderOptionsToProto({l2Normalize: true}); + verifyOption(embedderOptionsProto, {'l2Normalize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: false}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': false}); + }); + + it('merges options', () => { + let embedderOptionsProto = convertEmbedderOptionsToProto({quantize: true}); + verifyOption(embedderOptionsProto, {'quantize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: true}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': true, 'quantize': true}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_result.test.ts b/mediapipe/tasks/web/components/processors/embedder_result.test.ts new file mode 100644 index 000000000..97ba935c8 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.test.ts @@ -0,0 +1,75 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +import 'jasmine'; + +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; + +import {convertFromEmbeddingResultProto} from './embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromEmbeddingResultProto()', () => { + it('transforms custom values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + expect(timestampMs).toEqual(1); + }); + + it('transforms custom quantized values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + expect(timestampMs).toEqual(1); + }); +}); diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD index 1c1ba69ca..f4a215e48 100644 --- a/mediapipe/tasks/web/components/utils/BUILD +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -1,4 +1,5 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -9,3 +10,18 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", ], ) + +mediapipe_ts_library( + name = "cosine_similarity_test_lib", + testonly = True, + srcs = ["cosine_similarity.test.ts"], + deps = [ + ":cosine_similarity", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +jasmine_node_test( + name = "cosine_similarity_test", + deps = [":cosine_similarity_test_lib"], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts new file mode 100644 index 000000000..f442caa20 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts @@ -0,0 +1,85 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *
Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +import {computeCosineSimilarity} from './cosine_similarity'; + +describe('computeCosineSimilarity', () => { + it('fails with quantized and float embeddings', () => { + const u: Embedding = {floatEmbedding: [1.0], headIndex: 0, headName: ''}; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([1.0]), + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between quantized and float embeddings/); + }); + + it('fails with zero norm', () => { + const u = {floatEmbedding: [0.0], headIndex: 0, headName: ''}; + expect(() => computeCosineSimilarity(u, u)) + .toThrowError( + /Cannot compute cosine similarity on embedding with 0 norm/); + }); + + it('fails with different sizes', () => { + const u: + Embedding = {floatEmbedding: [1.0, 2.0], headIndex: 0, headName: ''}; + const v: Embedding = { + floatEmbedding: [1.0, 2.0, 3.0], + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between embeddings of different sizes/); + }); + + it('succeeds with float embeddings', () => { + const u: Embedding = { + floatEmbedding: [1.0, 0.0, 0.0, 0.0], + headIndex: 0, + headName: '' + }; + const v: Embedding = { + floatEmbedding: [0.5, 0.5, 0.5, 0.5], + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(0.5); + }); + + it('succeeds with quantized embeddings', () => { + const u: Embedding = { + quantizedEmbedding: new Uint8Array([255, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([0, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(-1.0); + }); +}); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index de429690d..1721661f5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,6 +1,7 @@ # This package contains options shared by all MediaPipe Tasks for Web. 
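The cosine-similarity tests above fix both the error messages and the expected values for the float path. A minimal reference implementation consistent with those expectations might look as follows; this is a sketch only, and the real `computeCosineSimilarity` additionally converts quantized `Uint8Array` embeddings to signed values before applying the same formula:

```ts
// Sketch of the float path; error messages mirror the regexes in the tests.
function cosineSimilaritySketch(u: number[], v: number[]): number {
  if (u.length !== v.length) {
    throw new Error(
        'Cannot compute cosine similarity between embeddings of different sizes.');
  }
  let dot = 0;
  let normU = 0;
  let normV = 0;
  for (let i = 0; i < u.length; i++) {
    dot += u[i] * v[i];
    normU += u[i] * u[i];
    normV += v[i] * v[i];
  }
  if (normU === 0 || normV === 0) {
    throw new Error(
        'Cannot compute cosine similarity on embedding with 0 norm.');
  }
  return dot / (Math.sqrt(normU) * Math.sqrt(normV));
}

// With u = [1, 0, 0, 0] and v = [0.5, 0.5, 0.5, 0.5]:
// dot = 0.5, |u| = 1, |v| = 1, so the result is 0.5 as the test expects.
```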
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -28,9 +29,42 @@ mediapipe_ts_library( mediapipe_ts_library( name = "fileset_resolver", srcs = ["fileset_resolver.ts"], + visibility = ["//visibility:public"], deps = [":core"], ) +mediapipe_ts_library( + name = "task_runner_test_utils", + testonly = True, + srcs = [ + "task_runner_test_utils.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "task_runner_test_lib", + testonly = True, + srcs = [ + "task_runner_test.ts", + ], + deps = [ + ":task_runner", + ":task_runner_test_utils", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "task_runner_test", + deps = [":task_runner_test_lib"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index d769139bc..2011fadef 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -32,8 +32,35 @@ const GraphRunnerImageLibType = /** An implementation of the GraphRunner that supports image operations */ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} +/** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ +export async function createTaskRunner( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: TaskRunnerOptions): Promise { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + // Initialize a canvas if requested. If OffscreenCanvas is availble, we + // let the graph runner initialize it by passing `undefined`. + const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; +} + /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; @@ -43,33 +70,18 @@ export abstract class TaskRunner { * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. */ - protected static async createInstance, - O extends TaskRunnerOptions>( + protected static async createInstance( type: WasmMediaPipeConstructor, initializeCanvas: boolean, - fileset: WasmFileset, options: O): Promise { - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return fileset.wasmBinaryPath.toString(); - } - }; - - // Initialize a canvas if requested. If OffscreenCanvas is availble, we - // let the graph runner initialize it by passing `undefined`. 
- const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? - document.createElement('canvas') : - undefined) : - null; - const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); - await instance.setOptions(options); - return instance; + fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> { + return createTaskRunner(type, initializeCanvas, fileset, options); } + /** @hideconstructor protected */ constructor( - wasmModule: WasmModule, - glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); + wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + graphRunner?: GraphRunnerImageLib) { + this.graphRunner = + graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. @@ -80,7 +92,7 @@ export abstract class TaskRunner { } /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: O): Promise<void> { + async setOptions(options: TaskRunnerOptions): Promise<void> { if (options.baseOptions) { this.baseOptions = await convertBaseOptionsToProto( options.baseOptions, this.baseOptions); diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts new file mode 100644 index 000000000..c9aad9d25 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -0,0 +1,107 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {TaskRunner} from '../../../tasks/web/core/task_runner'; +import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; +import {ErrorListener} from '../../../web/graph_runner/graph_runner'; + +import {GraphRunnerImageLib} from './task_runner'; + +class TaskRunnerFake extends TaskRunner { + protected baseOptions = new BaseOptionsProto(); + private errorListener: ErrorListener|undefined; + private errors: string[] = []; + + static createFake(): TaskRunnerFake { + const wasmModule = createSpyWasmModule(); + return new TaskRunnerFake(wasmModule); + } + + constructor(wasmModuleFake: SpyWasmModule) { + super( + wasmModuleFake, /* glCanvas= */ null, + jasmine.createSpyObj([ + 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', + 'registerModelResourcesGraphService', 'attachErrorListener' + ])); + const graphRunner = this.graphRunner as jasmine.SpyObj<GraphRunnerImageLib>; + expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled(); + expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); + graphRunner.attachErrorListener.and.callFake(listener => { + this.errorListener = listener; + }); + graphRunner.setGraph.and.callFake(() => { + this.throwErrors(); + }); + graphRunner.finishProcessing.and.callFake(() => { + this.throwErrors(); + }); + } + + enqueueError(message: string): void { + this.errors.push(message); + } + + override finishProcessing(): void { + super.finishProcessing(); + } + + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + super.setGraph(graphData, isBinary); + } + + private throwErrors(): void { + expect(this.errorListener).toBeDefined(); + for (const error of this.errors) { + this.errorListener!(/* errorCode= */ -1, error); + } + this.errors = []; + } +} + +describe('TaskRunner', () => { + it('handles errors during graph update', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError('Test error'); + }); + + it('handles errors during graph execution', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.finishProcessing(); + }).toThrowError('Test error'); + }); + + it('can handle multiple errors', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error 1'); + taskRunner.enqueueError('Test error 2'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error 1, Test error 2/); + }); +}); diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts new file mode 100644 index 000000000..2a1161a55 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -0,0 +1,113 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; +import {WasmModule} from '../../../web/graph_runner/graph_runner'; +import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service'; + +type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources; + +/** + * Convenience type for our fake WasmModule for Jasmine testing. + */ +export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>; + +/** + * Factory function for creating a fake WasmModule for our Jasmine tests, + * allowing our APIs to no longer rely on the Wasm layer so they can run tests + * in pure JS/TS (and optionally spy on the calls). + */ +export function createSpyWasmModule(): SpyWasmModule { + return jasmine.createSpyObj([ + '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', + '_attachProtoVectorListener', '_free', '_waitUntilIdle', + '_addStringToInputStream', '_registerModelResourcesGraphService', + '_configureAudio' + ]); +} + +/** + * Sets up our equality testing to use a custom float equality checking function + * to avoid incorrect test results due to minor floating point inaccuracies. + */ +export function addJasmineCustomFloatEqualityTester() { + jasmine.addCustomEqualityTester((a, b) => { // Custom float equality + if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { + return Math.abs(a - b) < 5e-8; + } + return; + }); +} + +/** The minimum interface provided by a test fake. */ +export interface MediapipeTasksFake { + graph: CalculatorGraphConfig|undefined; + calculatorName: string; + attachListenerSpies: jasmine.Spy[]; +} + +/** A map of field paths to values */ +export type FieldPathToValue = [string[] | string, unknown]; + +/** + * Verifies that the graph has been initialized and that it contains the + * provided options. + */ +export function verifyGraph( + tasksFake: MediapipeTasksFake, + expectedCalculatorOptions?: FieldPathToValue, + expectedBaseOptions?: FieldPathToValue, + ): void { + expect(tasksFake.graph).toBeDefined(); + expect(tasksFake.graph!.getNodeList().length).toBe(1); + const node = tasksFake.graph!.getNodeList()[0].toObject(); + expect(node).toEqual( + jasmine.objectContaining({calculator: tasksFake.calculatorName})); + + if (expectedBaseOptions) { + const [fieldPath, value] = expectedBaseOptions; + let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record<string, unknown>)[fieldName]; + } + expect(proto).toEqual(value); + } + + if (expectedCalculatorOptions) { + const [fieldPath, value] = expectedCalculatorOptions; + let proto = (node.options as {ext: unknown}).ext; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record<string, unknown>)[fieldName]; + } + expect(proto).toEqual(value); + } +} + +/** + * Verifies all listeners (as exposed by `.attachListenerSpies`) have been + * attached at least once since the last call to `verifyListenersRegistered()`.
+ * This helps us to ensure that listeners are re-registered with every graph + * update. + */ +export function verifyListenersRegistered(tasksFake: MediapipeTasksFake): void { + for (const spy of tasksFake.attachListenerSpies) { + expect(spy.calls.count()).toBeGreaterThanOrEqual(1); + spy.calls.reset(); + } +} diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs index e633bf702..3b5119530 100644 --- a/mediapipe/tasks/web/rollup.config.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,15 +1,9 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; -import replace from '@rollup/plugin-replace'; import terser from '@rollup/plugin-terser'; export default { plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), resolve(), commonjs(), terser() diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 159db1a0d..32f43d4b6 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 2a7de21d6..fd97c3db4 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -4,6 +4,7 @@ # BERT-based text classification). 
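The BUILD changes below also make the text-classifier targets publicly visible, so the task can be consumed directly from application code. A hedged usage sketch follows; the `@mediapipe/tasks-text` package name, the `FilesetResolver.forTextTasks` helper, and both URLs are assumptions not established by this diff:

```ts
import {FilesetResolver, TextClassifier} from '@mediapipe/tasks-text';

async function classifyExample(): Promise<void> {
  // Resolve the Wasm fileset; the URL here is a placeholder assumption.
  const fileset = await FilesetResolver.forTextTasks(
      'https://example.com/mediapipe/wasm');
  // Model path is caller-supplied; any classifier options shown in the
  // tests (maxResults, displayNamesLocale, ...) can be set here.
  const classifier = await TextClassifier.createFromOptions(fileset, {
    baseOptions: {modelAssetPath: 'text_classifier.tflite'},
    maxResults: 3,
  });
  const result = classifier.classify('MediaPipe makes on-device ML easy.');
  console.log(result.classifications[0].categories);
}
```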
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,6 +13,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", srcs = ["text_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +38,7 @@ mediapipe_ts_declaration( "text_classifier_options.d.ts", "text_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", @@ -43,3 +46,24 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "text_classifier_test_lib", + testonly = True, + srcs = [ + "text_classifier_test.ts", + ], + deps = [ + ":text_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "text_classifier_test", + deps = [":text_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 8810d4b42..62708700a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -41,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); @@ -92,6 +92,7 @@ export class TextClassifier extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts new file mode 100644 index 000000000..841bf8c48 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -0,0 +1,152 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import 'jasmine';
+
+// Placeholder for internal dependency on encodeByteArray
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {TextClassifier} from './text_classifier';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+class TextClassifierFake extends TextClassifier implements MediapipeTasksFake {
+  calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+  fakeWasmModule: SpyWasmModule;
+  protoListener: ((binaryProto: Uint8Array) => void)|undefined;
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toEqual('classifications_out');
+              this.protoListener = listener;
+            });
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+  }
+}
+
+describe('TextClassifier', () => {
+  let textClassifier: TextClassifierFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    textClassifier = new TextClassifierFake();
+    await textClassifier.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(textClassifier);
+    verifyListenersRegistered(textClassifier);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    await textClassifier.setOptions({maxResults: 1});
+    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]);
+    verifyListenersRegistered(textClassifier);
+
+    await textClassifier.setOptions({maxResults: 5});
+    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]);
+    verifyListenersRegistered(textClassifier);
+  });
+
+  it('can use custom models', async () => {
+    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
+    const newModelBase64 = Buffer.from(newModel).toString('base64');
+    await textClassifier.setOptions({
+      baseOptions: {
+        modelAssetBuffer: newModel,
+      }
+    });
+
+    verifyGraph(
+        textClassifier,
+        /* expectedCalculatorOptions= */ undefined,
+        /* expectedBaseOptions= */
+        [
+          'modelAsset', {
+            fileContent: newModelBase64,
+            fileName: undefined,
+            fileDescriptorMeta: undefined,
+            filePointerMeta: undefined
+          }
+        ]);
+  });
+
+  it('merges options', async () => {
+    await textClassifier.setOptions({maxResults: 1});
+    await textClassifier.setOptions({displayNamesLocale: 'en'});
+    verifyGraph(textClassifier, [
+      'classifierOptions', {
+        maxResults: 1,
+        displayNamesLocale: 'en',
+        scoreThreshold: undefined,
+        categoryAllowlistList: [],
+        categoryDenylistList: []
+      }
+    ]);
+  });
+
+  it('transforms results', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    // Pass the test data to our listener
+    textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(textClassifier);
+      textClassifier.protoListener!(classificationResult.serializeBinary());
+    });
+
+    // Invoke the text classifier
+    const result = textClassifier.classify('foo');
+
+    expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }]
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD
index 17d105258..1514944bf 100644
--- a/mediapipe/tasks/web/text/text_embedder/BUILD
+++ b/mediapipe/tasks/web/text/text_embedder/BUILD
@@ -4,6 +4,7 @@
 #
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -12,6 +13,7 @@ licenses(["notice"])
 mediapipe_ts_library(
     name = "text_embedder",
     srcs = ["text_embedder.ts"],
+    visibility = ["//visibility:public"],
     deps = [
         ":text_embedder_types",
         "//mediapipe/framework:calculator_jspb_proto",
@@ -36,9 +38,30 @@ mediapipe_ts_declaration(
         "text_embedder_options.d.ts",
         "text_embedder_result.d.ts",
     ],
+    visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/tasks/web/components/containers:embedding_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "text_embedder_test.ts",
+    ],
+    deps = [
+        ":text_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_embedder_test",
+    deps = [":text_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
index 62f9b06db..611233e02 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -45,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR =
 /**
  * Performs embedding extraction on text.
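 *
 * A minimal usage sketch for orientation (the wasm asset path and model URL
 * below are illustrative placeholders, not part of this change):
 *
 *   const textFiles = await FilesetResolver.forTextTasks('/wasm');
 *   const embedder = await TextEmbedder.createFromModelPath(
 *       textFiles, 'https://example.com/embedder.tflite');
 *   const result = embedder.embed('input text');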
*/ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); @@ -96,6 +96,7 @@ export class TextEmbedder extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts new file mode 100644 index 000000000..04a9b371a --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -0,0 +1,165 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextEmbedder} from './text_embedder'; + +// The OSS JS API does not support the builder pattern. 
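+// (Concretely: the protos below are populated with individual setters, e.g.
+// `embedding.setHeadName('headName')`, rather than chained builder calls.)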
+// tslint:disable:jspb-use-builder-pattern + +class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextEmbedder', () => { + let textEmbedder: TextEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textEmbedder = new TextEmbedderFake(); + await textEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textEmbedder); + verifyListenersRegistered(textEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await textEmbedder.setOptions({quantize: true}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(textEmbedder); + + await textEmbedder.setOptions({quantize: undefined}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(textEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await textEmbedder.setOptions({quantize: true}); + await textEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + textEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('transforms results', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + }); + + it('transforms custom quantized values', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const 
quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingsResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingsResult.embeddings.length).toEqual(1); + expect(embeddingsResult.embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 42bc0a494..93493e873 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index b389a9b01..e4ea3036f 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Vision Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -22,3 +23,20 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_library( + name = "vision_task_runner_test_lib", + testonly = True, + srcs = ["vision_task_runner.test.ts"], + deps = [ + ":vision_task_runner", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "vision_task_runner_test", + deps = [":vision_task_runner_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts new file mode 100644 index 000000000..6cc9ea328 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -0,0 +1,99 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskRunner} from './vision_task_runner'; + +class VisionTaskRunnerFake extends VisionTaskRunner { + baseOptions = new BaseOptionsProto(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + } + + protected override process(): void {} + + override processImageData(image: ImageSource): void { + super.processImageData(image); + } + + override processVideoData(imageFrame: ImageSource, timestamp: number): void { + super.processVideoData(imageFrame, timestamp); + } +} + +describe('VisionTaskRunner', () => { + const streamMode = { + modelAsset: undefined, + useStreamMode: true, + acceleration: undefined, + }; + + const imageMode = { + modelAsset: undefined, + useStreamMode: false, + acceleration: undefined, + }; + + let visionTaskRunner: VisionTaskRunnerFake; + + beforeEach(() => { + visionTaskRunner = new VisionTaskRunnerFake(); + }); + + it('can enable image mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('can enable video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + }); + + it('can clear running mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + // Clear running mode + await visionTaskRunner.setOptions({runningMode: undefined}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('cannot process images with video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(() => { + visionTaskRunner.processImageData({} as HTMLImageElement); + }).toThrowError(/Task is not initialized with image mode./); + }); + + it('cannot process video with image mode', async () => { + // Use default for `useStreamMode` + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + + // Explicitly set to image mode + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 78b4859f2..3432b521b 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -20,8 +20,7 @@ import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner extends - TaskRunner { +export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. 
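 *
 * For example (illustrative call; any concrete vision task configures its
 * shared options the same way):
 *
 *   await visionTask.setOptions({runningMode: 'video'});
 *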
*/ override async setOptions(options: VisionTaskOptions): Promise { await super.setOptions(options); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index ddfd1a327..aa2f9c366 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more gesture categories, using Gesture Recognizer. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,6 +13,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", srcs = ["gesture_recognizer.ts"], + visibility = ["//visibility:public"], deps = [ ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", @@ -42,6 +44,7 @@ mediapipe_ts_declaration( "gesture_recognizer_options.d.ts", "gesture_recognizer_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", @@ -50,3 +53,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "gesture_recognizer_test_lib", + testonly = True, + srcs = [ + "gesture_recognizer_test.ts", + ], + deps = [ + ":gesture_recognizer", + ":gesture_recognizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "gesture_recognizer_test", + tags = ["nomsan"], + deps = [":gesture_recognizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 69a8118a6..b6b795076 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -127,6 +127,7 @@ export class GestureRecognizer extends {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts new file mode 100644 index 000000000..c0f0d1554 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -0,0 +1,307 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[]) => void); + +function createHandednesses(): Uint8Array[] { + const handsProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('handedness_label'); + classification.setDisplayName('handedness_display_name'); + handsProto.addClassification(classification); + return [handsProto.serializeBinary()]; +} + +function createGestures(): Uint8Array[] { + const gesturesProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.2); + classification.setIndex(2); + classification.setLabel('gesture_label'); + classification.setDisplayName('gesture_display_name'); + gesturesProto.addClassification(classification); + return [gesturesProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const handLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +function createWorldLandmarks(): Uint8Array[] { + const handLandmarksProto = new LandmarkList(); + const landmark = new Landmark(); + landmark.setX(21); + landmark.setY(22); + landmark.setZ(23); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +class GestureRecognizerFake extends GestureRecognizer implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): GraphRunnerImageLib { + return this.graphRunner; + } +} + +describe('GestureRecognizer', () => { + let gestureRecognizer: GestureRecognizerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + gestureRecognizer = new 
GestureRecognizerFake(); + await gestureRecognizer.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(gestureRecognizer); + verifyListenersRegistered(gestureRecognizer); + }); + + it('reloads graph when settings are changed', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + ]); + verifyListenersRegistered(gestureRecognizer); + + await gestureRecognizer.setOptions({numHands: 5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5 + ]); + verifyListenersRegistered(gestureRecognizer); + }); + + it('merges options', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + ]); + verifyGraph(gestureRecognizer, [ + [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + 0.5 + ]); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionPath: [keyof GestureRecognizerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numHands'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands' + ], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minHandDetectionConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minHandPresenceConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + { + optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'], + fieldPath: [ + 'handGestureRecognizerGraphOptions', + 'cannedGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.4, + defaultValue: undefined + }, + { + optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'], + fieldPath: [ + 'handGestureRecognizerGraphOptions', + 'customGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.5, + defaultValue: undefined, + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): GestureRecognizerOptions { + const options: Record = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear 
${testCase.optionPath[0]}`, async () => {
+        await gestureRecognizer.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
+
+        await gestureRecognizer.setOptions(
+            createOptions(testCase.optionPath, undefined));
+        verifyGraph(
+            gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
+      });
+    }
+  });
+
+  it('transforms results', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(gestureRecognizer);
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+    });
+
+    // Invoke the gesture recognizer
+    const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+    expect(gestureRecognizer.getGraphRunner().addProtoToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+    expect(gestures).toEqual({
+      'gestures': [[{
+        'score': 0.2,
+        'index': 2,
+        'categoryName': 'gesture_label',
+        'displayName': 'gesture_display_name'
+      }]],
+      'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+      'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+      'handednesses': [[{
+        'score': 0.1,
+        'index': 1,
+        'categoryName': 'handedness_label',
+        'displayName': 'handedness_display_name'
+      }]]
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
+    });
+
+    // Invoke the gesture recognizer twice
+    const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement);
+    const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement);
+
+    // Verify that gestures2 is not a concatenation of all previously returned
+    // gestures.
+    expect(gestures2).toEqual(gestures1);
+  });
+});
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
index fc3e6ef1f..d1f1e48f3 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
@@ -4,6 +4,7 @@
 # the detection results for one or more hand categories, using Hand Landmarker.
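# For reference, the jasmine_node_test target added below should be runnable
# on its own with something along the lines of
#   bazel test //mediapipe/tasks/web/vision/hand_landmarker:hand_landmarker_test
# (the command is illustrative; the exact invocation depends on local setup).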
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,6 +13,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "hand_landmarker", srcs = ["hand_landmarker.ts"], + visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", @@ -38,6 +40,7 @@ mediapipe_ts_declaration( "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", @@ -45,3 +48,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "hand_landmarker_test_lib", + testonly = True, + srcs = [ + "hand_landmarker_test.ts", + ], + deps = [ + ":hand_landmarker", + ":hand_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "hand_landmarker_test", + tags = ["nomsan"], + deps = [":hand_landmarker_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 9a0823f23..2a0e8286c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -115,6 +115,7 @@ export class HandLandmarker extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts new file mode 100644 index 000000000..fc26680e0 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -0,0 +1,251 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+import 'jasmine';
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
+import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {HandLandmarker} from './hand_landmarker';
+import {HandLandmarkerOptions} from './hand_landmarker_options';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
+
+function createHandednesses(): Uint8Array[] {
+  const handsProto = new ClassificationList();
+  const classification = new Classification();
+  classification.setScore(0.1);
+  classification.setIndex(1);
+  classification.setLabel('handedness_label');
+  classification.setDisplayName('handedness_display_name');
+  handsProto.addClassification(classification);
+  return [handsProto.serializeBinary()];
+}
+
+function createLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new NormalizedLandmarkList();
+  const landmark = new NormalizedLandmark();
+  landmark.setX(0.3);
+  landmark.setY(0.4);
+  landmark.setZ(0.5);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+function createWorldLandmarks(): Uint8Array[] {
+  const handLandmarksProto = new LandmarkList();
+  const landmark = new Landmark();
+  landmark.setX(21);
+  landmark.setY(22);
+  landmark.setZ(23);
+  handLandmarksProto.addLandmark(landmark);
+  return [handLandmarksProto.serializeBinary()];
+}
+
+class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
+  calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+  fakeWasmModule: SpyWasmModule;
+  listeners = new Map<string, ProtoListener>();
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoVectorListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toMatch(
+                  /(hand_landmarks|world_hand_landmarks|handedness)/);
+              this.listeners.set(stream, listener);
+            });
+
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+    spyOn(this.graphRunner, 'addProtoToStream');
+  }
+
+  getGraphRunner(): GraphRunnerImageLib {
+    return this.graphRunner;
+  }
+}
+
+describe('HandLandmarker', () => {
+  let handLandmarker: HandLandmarkerFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    handLandmarker = new HandLandmarkerFake();
+    await handLandmarker.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(handLandmarker);
+    verifyListenersRegistered(handLandmarker);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    verifyListenersRegistered(handLandmarker);
+
+    await handLandmarker.setOptions({numHands: 1});
+    verifyGraph(
+        handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]);
+    verifyListenersRegistered(handLandmarker);
+
+    await handLandmarker.setOptions({numHands: 5});
+    verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]);
+    verifyListenersRegistered(handLandmarker);
+  });
+
+  it('merges options', async () => {
+    await handLandmarker.setOptions({numHands: 1});
+    await handLandmarker.setOptions({minHandDetectionConfidence: 0.5});
+    verifyGraph(handLandmarker, [
+      'handDetectorGraphOptions',
+      {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5}
+    ]);
+  });
+
+  describe('setOptions() ', () => {
+    interface TestCase {
+      optionPath: [keyof HandLandmarkerOptions, ...string[]];
+      fieldPath: string[];
+      customValue: unknown;
+      defaultValue: unknown;
+    }
+
+    const testCases: TestCase[] = [
+      {
+        optionPath: ['numHands'],
+        fieldPath: ['handDetectorGraphOptions', 'numHands'],
+        customValue: 5,
+        defaultValue: 1
+      },
+      {
+        optionPath: ['minHandDetectionConfidence'],
+        fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'],
+        customValue: 0.1,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minHandPresenceConfidence'],
+        fieldPath:
+            ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
+        customValue: 0.2,
+        defaultValue: 0.5
+      },
+      {
+        optionPath: ['minTrackingConfidence'],
+        fieldPath: ['minTrackingConfidence'],
+        customValue: 0.3,
+        defaultValue: 0.5
+      },
+    ];
+
+    /** Creates an options object that can be passed to setOptions() */
+    function createOptions(
+        path: string[], value: unknown): HandLandmarkerOptions {
+      const options: Record<string, unknown> = {};
+      let currentLevel = options;
+      for (const element of path.slice(0, -1)) {
+        currentLevel[element] = {};
+        currentLevel = currentLevel[element] as Record<string, unknown>;
+      }
+      currentLevel[path[path.length - 1]] = value;
+      return options;
+    }
+
+    for (const testCase of testCases) {
+      it(`uses default value for ${testCase.optionPath[0]}`, async () => {
+        verifyGraph(
+            handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+      });
+
+      it(`can set ${testCase.optionPath[0]}`, async () => {
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+      });
+
+      it(`can clear ${testCase.optionPath[0]}`, async () => {
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, testCase.customValue));
+        verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
+
+        await handLandmarker.setOptions(
+            createOptions(testCase.optionPath, undefined));
+        verifyGraph(
+            handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
+      });
+    }
+  });
+
+  it('transforms results', async () => {
+    // Pass the test data to our listener
+    handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(handLandmarker);
+      handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
+      handLandmarker.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      handLandmarker.listeners.get('handedness')!(createHandednesses());
+    });
+
+    // Invoke the hand landmarker
+    const landmarks = handLandmarker.detect({} as HTMLImageElement);
+    expect(handLandmarker.getGraphRunner().addProtoToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
+        .toHaveBeenCalledTimes(1);
+    expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+
+    expect(landmarks).toEqual({
+      'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
+      'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
+      'handednesses': [[{
+        'score': 0.1,
+        'index': 1,
+        'categoryName': 'handedness_label',
+        'displayName': 'handedness_display_name'
+      }]]
+    });
+  });
+
+  it('clears results between invocations', async () => {
+    // Pass the test data to our listener
+    handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
+      handLandmarker.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      handLandmarker.listeners.get('handedness')!(createHandednesses());
+    });
+
+    // Invoke the hand landmarker twice
+    const landmarks1 = handLandmarker.detect({} as HTMLImageElement);
+    const landmarks2 = handLandmarker.detect({} as HTMLImageElement);
+
+    // Verify that landmarks2 is not a concatenation of all previously
+    // returned landmarks.
+    expect(landmarks1).toEqual(landmarks2);
+  });
+});
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index ebe64ecf4..310575964 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -3,6 +3,7 @@
 # This task takes video or image frames and outputs the classification result.
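# With the `visibility = ["//visibility:public"]` entries added below, targets
# outside //mediapipe/tasks can depend on this task library. A hypothetical
# downstream target might look like:
#
#   mediapipe_ts_library(
#       name = "my_app_lib",  # illustrative name, not part of this change
#       srcs = ["my_app.ts"],
#       deps = ["//mediapipe/tasks/web/vision/image_classifier"],
#   )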
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", srcs = ["image_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +37,7 @@ mediapipe_ts_declaration( "image_classifier_options.d.ts", "image_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", @@ -42,3 +45,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "image_classifier_test_lib", + testonly = True, + srcs = [ + "image_classifier_test.ts", + ], + deps = [ + ":image_classifier", + ":image_classifier_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "image_classifier_test", + tags = ["nomsan"], + deps = [":image_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 40e8b5099..36e7311fb 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -93,6 +93,7 @@ export class ImageClassifier extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts new file mode 100644 index 000000000..2041a0cef --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -0,0 +1,150 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import 'jasmine';
+
+// Placeholder for internal dependency on encodeByteArray
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {ImageClassifier} from './image_classifier';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+class ImageClassifierFake extends ImageClassifier implements
+    MediapipeTasksFake {
+  calculatorName =
+      'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+
+  fakeWasmModule: SpyWasmModule;
+  protoListener: ((binaryProto: Uint8Array) => void)|undefined;
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachProtoListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toEqual('classifications');
+              this.protoListener = listener;
+            });
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+  }
+}
+
+describe('ImageClassifier', () => {
+  let imageClassifier: ImageClassifierFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    imageClassifier = new ImageClassifierFake();
+    await imageClassifier.setOptions({});  // Initialize graph
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(imageClassifier);
+    verifyListenersRegistered(imageClassifier);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    await imageClassifier.setOptions({maxResults: 1});
+    verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
+    verifyListenersRegistered(imageClassifier);
+
+    await imageClassifier.setOptions({maxResults: 5});
+    verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]);
+    verifyListenersRegistered(imageClassifier);
+  });
+
+  it('can use custom models', async () => {
+    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
+    const newModelBase64 = Buffer.from(newModel).toString('base64');
+    await imageClassifier.setOptions({
+      baseOptions: {
+        modelAssetBuffer: newModel,
+      }
+    });
+
+    verifyGraph(
+        imageClassifier,
+        /* expectedCalculatorOptions= */ undefined,
+        /* expectedBaseOptions= */[
+          'modelAsset', {
+            fileContent: newModelBase64,
+            fileName: undefined,
+            fileDescriptorMeta: undefined,
+            filePointerMeta: undefined
+          }
+        ]);
+  });
+
+  it('merges options', async () => {
+    await imageClassifier.setOptions({maxResults: 1});
+    await imageClassifier.setOptions({displayNamesLocale: 'en'});
+    verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
+    verifyGraph(
+        imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']);
+  });
+
+  it('transforms results', async () => {
+    const classificationResult = new ClassificationResult();
+    const classifications = new Classifications();
+    classifications.setHeadIndex(1);
+    classifications.setHeadName('headName');
+    const classificationList = new ClassificationList();
+    const classification = new Classification();
+    classification.setIndex(1);
+    classification.setScore(0.2);
+    classification.setDisplayName('displayName');
+    classification.setLabel('categoryName');
+    classificationList.addClassification(classification);
+    classifications.setClassificationList(classificationList);
+    classificationResult.addClassifications(classifications);
+
+    // Pass the test data to our listener
+    imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(imageClassifier);
+      imageClassifier.protoListener!(classificationResult.serializeBinary());
+    });
+
+    // Invoke the image classifier
+    const result = imageClassifier.classify({} as HTMLImageElement);
+
+    expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    expect(result).toEqual({
+      classifications: [{
+        categories: [{
+          index: 1,
+          score: 0.2,
+          displayName: 'displayName',
+          categoryName: 'categoryName'
+        }],
+        headIndex: 1,
+        headName: 'headName'
+      }]
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index 2f012dc5e..de4785e6c 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -3,6 +3,7 @@
 # This task performs embedding extraction on images.
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -11,6 +12,7 @@ licenses(["notice"])
 mediapipe_ts_library(
     name = "image_embedder",
     srcs = ["image_embedder.ts"],
+    visibility = ["//visibility:public"],
     deps = [
         ":image_embedder_types",
         "//mediapipe/framework:calculator_jspb_proto",
@@ -36,6 +38,7 @@ mediapipe_ts_declaration(
         "image_embedder_options.d.ts",
         "image_embedder_result.d.ts",
     ],
+    visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/tasks/web/components/containers:embedding_result",
         "//mediapipe/tasks/web/core",
@@ -43,3 +46,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "image_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "image_embedder_test.ts",
+    ],
+    deps = [
+        ":image_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "image_embedder_test",
+    deps = [":image_embedder_test_lib"],
+)
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index f8b0204ee..0c45ba5e7 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -95,6 +95,7 @@ export class ImageEmbedder extends VisionTaskRunner {
         {baseOptions: {modelAssetPath}});
   }
 
+  /** @hideconstructor */
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
new file mode 100644
index 000000000..cafe0f3d8
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2022 The MediaPipe Authors.
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageEmbedder} from './image_embedder'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageEmbedder', () => { + let imageEmbedder: ImageEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageEmbedder = new ImageEmbedderFake(); + await imageEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageEmbedder); + verifyListenersRegistered(imageEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: true}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: undefined}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(imageEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('overrides options', async () => { + await 
imageEmbedder.setOptions({quantize: true}); + await imageEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + imageEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + describe('transforms result', () => { + beforeEach(() => { + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + embedding.setFloatEmbedding(floatEmbedding); + + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(42); + + // Pass the test data to our listener + imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageEmbedder); + imageEmbedder.protoListener!(resultProto.serializeBinary()); + }); + }); + + it('for image mode', async () => { + // Invoke the image embedder + const embeddingResult = imageEmbedder.embed({} as HTMLImageElement); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + + it('for video mode', async () => { + await imageEmbedder.setOptions({runningMode: 'video'}); + + // Invoke the video embedder + const embeddingResult = + imageEmbedder.embedForVideo({} as HTMLImageElement, 42); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 198585258..fc206a2d7 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more object categories, using Object Detector. 
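# Note: `tags = ["nomsan"]` on the test target below is the usual Bazel tag
# for keeping a target out of MemorySanitizer runs; stated as background, it
# is not otherwise documented in this change.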
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -12,6 +13,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", srcs = ["object_detector.ts"], + visibility = ["//visibility:public"], deps = [ ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", @@ -32,6 +34,7 @@ mediapipe_ts_declaration( "object_detector_options.d.ts", "object_detector_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", @@ -39,3 +42,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "object_detector_test_lib", + testonly = True, + srcs = [ + "object_detector_test.ts", + ], + deps = [ + ":object_detector", + ":object_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "object_detector_test", + tags = ["nomsan"], + deps = [":object_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e2cfe0575..fbfaced12 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -92,6 +92,7 @@ export class ObjectDetector extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts new file mode 100644 index 000000000..fff1a1c48 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ObjectDetector} from './object_detector'; +import {ObjectDetectorOptions} from './object_detector_options'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ObjectDetector', () => { + let objectDetector: ObjectDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + objectDetector = new ObjectDetectorFake(); + await objectDetector.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(objectDetector); + verifyListenersRegistered(objectDetector); + }); + + it('reloads graph when settings are changed', async () => { + await objectDetector.setOptions({maxResults: 1}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyListenersRegistered(objectDetector); + + await objectDetector.setOptions({maxResults: 5}); + verifyGraph(objectDetector, ['maxResults', 5]); + verifyListenersRegistered(objectDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await objectDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + objectDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await objectDetector.setOptions({maxResults: 1}); + await objectDetector.setOptions({displayNamesLocale: 'en'}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyGraph(objectDetector, ['displayNamesLocale', 'en']); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionName: keyof ObjectDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, 
+ { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + await objectDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + objectDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('transforms results', async () => { + const detectionProtos: Uint8Array[] = []; + + // Add a detection with all optional properties + let detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + let locationData = new LocationData(); + let boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Add a detection without optional properties + detection = new DetectionProto(); + detection.addScore(0.2); + locationData = new LocationData(); + boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Pass the test data to our listener + objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(objectDetector); + objectDetector.protoListener!(detectionProtos); + }); + + // Invoke the object detector + const detections = objectDetector.detect({} as HTMLImageElement); + + expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(2); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + }); + expect(detections[1]).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/util/rectangle_util_test.cc b/mediapipe/util/rectangle_util_test.cc index cd1946d45..3bc323f9f 100644 --- a/mediapipe/util/rectangle_util_test.cc +++ b/mediapipe/util/rectangle_util_test.cc @@ -20,6 +20,7 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; using ::testing::FloatNear; class RectangleUtilTest : public testing::Test { diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 8f40154a0..38636f32e 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -37,6 +37,8 @@ absl::Status 
GetResourceContents(const std::string& path, std::string* output, return internal::DefaultGetResourceContents(path, output, read_as_binary); } +bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; } + void SetCustomGlobalResourceProvider(ResourceProviderFn fn) { resource_provider_ = std::move(fn); } diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h index 6bc1513c6..e74af8b2e 100644 --- a/mediapipe/util/resource_util_custom.h +++ b/mediapipe/util/resource_util_custom.h @@ -10,6 +10,9 @@ namespace mediapipe { typedef std::function<absl::Status(const std::string&, std::string*)> ResourceProviderFn; +// Returns true if files are provided via a custom resource provider. +bool HasCustomGlobalResourceProvider(); + // Overrides the behavior of GetResourceContents. void SetCustomGlobalResourceProvider(ResourceProviderFn fn); diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 6bca24446..816af2533 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -282,7 +282,6 @@ cc_library( srcs = ["motion_models_cv.cc"], hdrs = ["motion_models_cv.h"], deps = [ - ":camera_motion_cc_proto", ":motion_models", ":motion_models_cc_proto", "//mediapipe/framework/port:opencv_core", diff --git a/mediapipe/util/tracking/tracked_detection.cc b/mediapipe/util/tracking/tracked_detection.cc index 130a87640..80a6981a8 100644 --- a/mediapipe/util/tracking/tracked_detection.cc +++ b/mediapipe/util/tracking/tracked_detection.cc @@ -20,6 +20,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + // Struct for carrying boundary information. struct NormalizedRectBounds { float left, right, top, bottom; diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 597827f3c..a9e348ceb 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,6 +21,7 @@ namespace { +using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view.
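The resource_util_custom.h change above adds HasCustomGlobalResourceProvider(), so startup code can probe whether a provider is already installed before overriding GetResourceContents. A minimal sketch of that pattern, assuming the two-argument ResourceProviderFn signature shown above; GetFromAppBundle and the "model.tflite" lookup are hypothetical:

#include <string>

#include "absl/status/status.h"
#include "mediapipe/util/resource_util_custom.h"

// Hypothetical provider that serves assets from an app bundle instead of
// the filesystem.
static absl::Status GetFromAppBundle(const std::string& path,
                                     std::string* output) {
  if (path == "model.tflite") {
    *output = "<model bytes>";  // stand-in for real asset contents
    return absl::OkStatus();
  }
  return absl::NotFoundError(path);
}

int main() {
  // Register once; later initialization code can check the flag instead of
  // clobbering an already-installed provider.
  if (!mediapipe::HasCustomGlobalResourceProvider()) {
    mediapipe::SetCustomGlobalResourceProvider(&GetFromAppBundle);
  }
  return 0;
}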
diff --git a/mediapipe/util/tracking/tracked_detection_test.cc b/mediapipe/util/tracking/tracked_detection_test.cc index 60b9df1b1..13efaab92 100644 --- a/mediapipe/util/tracking/tracked_detection_test.cc +++ b/mediapipe/util/tracking/tracked_detection_test.cc @@ -18,6 +18,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + const float kErrorMargin = 1e-4f; TEST(TrackedDetectionTest, ConstructorWithoutBox) { diff --git a/package.json b/package.json index 22a035b74..89b62bc83 100644 --- a/package.json +++ b/package.json @@ -3,15 +3,19 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/jasmine": "^5.7.2", "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", - "@rollup/plugin-replace": "^5.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", + "@types/jasmine": "^4.3.1", + "@types/node": "^18.11.11", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", + "jasmine": "^4.5.0", + "jasmine-core": "^4.5.0", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", "rollup": "^2.3.0", diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD new file mode 100644 index 000000000..05f830e81 --- /dev/null +++ b/third_party/apple_frameworks/BUILD @@ -0,0 +1,73 @@ +# Build rules to inject Apple Frameworks + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "CoreGraphics", + linkopts = ["-framework CoreGraphics"], +) + +cc_library( + name = "CoreMedia", + linkopts = ["-framework CoreMedia"], +) + +cc_library( + name = "UIKit", + linkopts = ["-framework UIKit"], +) + +cc_library( + name = "Accelerate", + linkopts = ["-framework Accelerate"], +) + +cc_library( + name = "CoreVideo", + linkopts = ["-framework CoreVideo"], +) + +cc_library( + name = "Metal", + linkopts = ["-framework Metal"], +) + +cc_library( + name = "MetalPerformanceShaders", + linkopts = ["-framework MetalPerformanceShaders"], +) + +cc_library( + name = "AVFoundation", + linkopts = ["-framework AVFoundation"], +) + +cc_library( + name = "Foundation", + linkopts = ["-framework Foundation"], +) + +cc_library( + name = "CoreImage", + linkopts = ["-framework CoreImage"], +) + +cc_library( + name = "XCTest", + linkopts = ["-framework XCTest"], +) + +cc_library( + name = "GLKit", + linkopts = ["-framework GLKit"], +) + +cc_library( + name = "OpenGLES", + linkopts = ["-framework OpenGLES"], +) + +cc_library( + name = "QuartzCore", + linkopts = ["-framework QuartzCore"], +) diff --git a/tsconfig.json b/tsconfig.json index c17b1902e..970246dbb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "inlineSourceMap": true, "inlineSources": true, "strict": true, - "types": ["@types/offscreencanvas"], + "types": ["@types/offscreencanvas", "@types/jasmine", "node"], "rootDirs": [ ".", "./bazel-out/host/bin", diff --git a/yarn.lock b/yarn.lock index 19c32e322..9c4d91d30 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,34 +3,52 @@ "@babel/parser@^7.9.4": - version "7.20.3" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" - integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + version "7.20.5" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.5.tgz#7f3c7335fe417665d929f34ae5dceae4c04015e8" + integrity 
sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA== + +"@bazel/jasmine@^5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/jasmine/-/jasmine-5.7.2.tgz#438f272e66e939106cbdd58db709cd6aa008131b" + integrity sha512-RJruOB6S9e0efTNIE2JVdaslguUXh5KcmLUCq/xLCt0zENP44ssp9OooDIrZ8H+Sp4mLDNBX7CMMA9WTsbsxTQ== + dependencies: + c8 "~7.5.0" + jasmine-reporters "~2.5.0" "@bazel/rollup@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" - integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.2.tgz#9953b06e3de52794791cee4f89540c263b035fcf" + integrity sha512-yGWLheSKdMnJ/Y3/qg+zCDx/qkD04FBFp+BjRS8xP4yvlz9G4rW3zc45VzHHz3oOywgQaY1vhfKuZMCcjTGEyA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" "@bazel/typescript@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" - integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.2.tgz#a341215dc93ce28794e8430b311756816140bd78" + integrity sha512-tarBJBEIirnq/YaeYu18vXcDxjzlq4xhCXvXUxA0lhHX5oArjEcAEn4tmO0jF+t/7cbkAdMT7daG6vIHSz0QAA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" semver "5.6.0" source-map-support "0.5.9" tsutils "3.21.0" -"@bazel/worker@5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" - integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== +"@bazel/worker@5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.2.tgz#43d800dc1b5a3707340a4eb0102da81c53fc6f63" + integrity sha512-H+auDA0QKF4mtZxKkZ2OKJvD7hGXVsVKtvcf4lbb93ur0ldpb5k810PcDxngmIGBcIX5kmyxniNTIiGFNobWTg== dependencies: google-protobuf "^3.6.1" +"@bcoe/v8-coverage@^0.2.3": + version "0.2.3" + resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" + integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== + +"@istanbuljs/schema@^0.1.2": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98" + integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA== + "@jridgewell/gen-mapping@^0.3.0": version "0.3.2" resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" @@ -125,9 +143,9 @@ integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== "@rollup/plugin-commonjs@^23.0.2": - version "23.0.2" - resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" - integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + version "23.0.3" + resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.3.tgz#442cd8ccca1b7563a503da86fc84a1a7112b54bb" + 
integrity sha512-31HxrT5emGfTyIfAs1lDQHj6EfYxTXcwtX5pIIhq+B/xZBNIqQ179d/CkYxlpYmFCxT78AeU4M8aL8Iv/IBxFA== dependencies: "@rollup/pluginutils" "^5.0.1" commondir "^1.0.1" @@ -148,14 +166,6 @@ is-module "^1.0.0" resolve "^1.22.1" -"@rollup/plugin-replace@^5.0.1": - version "5.0.1" - resolved "https://registry.yarnpkg.com/@rollup/plugin-replace/-/plugin-replace-5.0.1.tgz#49a57af3e6df111a9e75dea3f3572741f4c5c83e" - integrity sha512-Z3MfsJ4CK17BfGrZgvrcp/l6WXoKb0kokULO+zt/7bmcyayokDaQ2K3eDJcRLCTAlp5FPI4/gz9MHAsosz4Rag== - dependencies: - "@rollup/pluginutils" "^5.0.1" - magic-string "^0.26.4" - "@rollup/plugin-terser@^0.1.0": version "0.1.0" resolved "https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" @@ -182,6 +192,21 @@ resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== +"@types/is-windows@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/is-windows/-/is-windows-1.0.0.tgz#1011fa129d87091e2f6faf9042d6704cdf2e7be0" + integrity sha512-tJ1rq04tGKuIJoWIH0Gyuwv4RQ3+tIu7wQrC0MV47raQ44kIzXSSFKfrxFUOWVRvesoF7mrTqigXmqoZJsXwTg== + +"@types/istanbul-lib-coverage@^2.0.1": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz#8467d4b3c087805d63580480890791277ce35c44" + integrity sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g== + +"@types/jasmine@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-4.3.1.tgz#2d8ab5601c2fe7d9673dcb157e03f128ab5c5fff" + integrity sha512-Vu8l+UGcshYmV1VWwULgnV/2RDbBaO6i2Ptx7nd//oJPIZGhoI1YLST4VKagD2Pq/Bc2/7zvtvhM7F3p4SN7kQ== + "@types/linkify-it@*": version "3.0.2" resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" @@ -200,10 +225,10 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/node@>=13.7.0": - version "18.11.9" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" - integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== +"@types/node@>=13.7.0", "@types/node@^18.11.11": + version "18.11.11" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.11.tgz#1d455ac0211549a8409d3cdb371cd55cc971e8dc" + integrity sha512-KJ021B1nlQUBLopzZmPBVuGU9un7WJd/W4ya7Ih02B4Uwky5Nja0yGYav2EfYIk0RR2Q9oVhf60S2XR1BCWJ2g== "@types/offscreencanvas@^2019.7.0": version "2019.7.0" @@ -215,6 +240,11 @@ resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" integrity sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== +"@xmldom/xmldom@^0.8.5": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + acorn-jsx@^5.3.2: version "5.3.2" resolved 
"https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" @@ -225,7 +255,12 @@ acorn@^8.5.0, acorn@^8.8.0: resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== -ansi-styles@^4.1.0: +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.3.0" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== @@ -272,6 +307,25 @@ builtin-modules@^3.3.0: resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== +c8@~7.5.0: + version "7.5.0" + resolved "https://registry.yarnpkg.com/c8/-/c8-7.5.0.tgz#a69439ab82848f344a74bb25dc5dd4e867764481" + integrity sha512-GSkLsbvDr+FIwjNSJ8OwzWAyuznEYGTAd1pzb/Kr0FMLuV4vqYJTyjboDTwmlUNAG6jAU3PFWzqIdKrOt1D8tw== + dependencies: + "@bcoe/v8-coverage" "^0.2.3" + "@istanbuljs/schema" "^0.1.2" + find-up "^5.0.0" + foreground-child "^2.0.0" + furi "^2.0.0" + istanbul-lib-coverage "^3.0.0" + istanbul-lib-report "^3.0.0" + istanbul-reports "^3.0.2" + rimraf "^3.0.0" + test-exclude "^6.0.0" + v8-to-istanbul "^7.1.0" + yargs "^16.0.0" + yargs-parser "^20.0.0" + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -287,6 +341,15 @@ chalk@^4.0.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + color-convert@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" @@ -314,6 +377,20 @@ concat-map@0.0.1: resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== +convert-source-map@^1.6.0: + version "1.9.0" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.9.0.tgz#7faae62353fb4213366d0ca98358d22e8368b05f" + integrity sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A== + +cross-spawn@^7.0.0: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + deep-is@~0.1.3: version "0.1.4" resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" @@ 
-324,11 +401,21 @@ deepmerge@^4.2.2: resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" @@ -390,6 +477,22 @@ fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== +find-up@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-5.0.0.tgz#4c92819ecb7083561e4f4a240a86be5198f536fc" + integrity sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng== + dependencies: + locate-path "^6.0.0" + path-exists "^4.0.0" + +foreground-child@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/foreground-child/-/foreground-child-2.0.0.tgz#71b32800c9f15aa8f2f83f4a6bd9bff35d861a53" + integrity sha512-dCIq9FpEcyQyXKCkyzmlPTFNgrCzPudOe+mhvJU5zAtlBnGVy2yKxtfsxK2tQBThwq225jcvBjpw1Gr40uzZCA== + dependencies: + cross-spawn "^7.0.0" + signal-exit "^3.0.2" + fs.realpath@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -405,7 +508,20 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== -glob@^7.1.3: +furi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/furi/-/furi-2.0.0.tgz#13d85826a1af21acc691da6254b3888fc39f0b4a" + integrity sha512-uKuNsaU0WVaK/vmvj23wW1bicOFfyqSsAIH71bRZx8kA4Xj+YCHin7CJKJJjkIsmxYaPFLk9ljmjEyB7xF7WvQ== + dependencies: + "@types/is-windows" "^1.0.0" + is-windows "^1.0.2" + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== @@ -450,6 +566,11 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" +html-escaper@^2.0.0: + version 
"2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -477,6 +598,11 @@ is-core-module@^2.9.0: dependencies: has "^1.0.3" +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" @@ -489,6 +615,59 @@ is-reference@1.2.1: dependencies: "@types/estree" "*" +is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== + +istanbul-lib-coverage@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz#189e7909d0a39fa5a3dfad5b03f71947770191d3" + integrity sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw== + +istanbul-lib-report@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#7518fe52ea44de372f460a76b5ecda9ffb73d8a6" + integrity sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw== + dependencies: + istanbul-lib-coverage "^3.0.0" + make-dir "^3.0.0" + supports-color "^7.1.0" + +istanbul-reports@^3.0.2: + version "3.1.5" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.1.5.tgz#cc9a6ab25cb25659810e4785ed9d9fb742578bae" + integrity sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w== + dependencies: + html-escaper "^2.0.0" + istanbul-lib-report "^3.0.0" + +jasmine-core@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine-core/-/jasmine-core-4.5.0.tgz#1a6bd0bde3f60996164311c88a0995d67ceda7c3" + integrity sha512-9PMzyvhtocxb3aXJVOPqBDswdgyAeSB81QnLop4npOpbqnheaTEwPc9ZloQeVswugPManznQBjD8kWDTjlnHuw== + +jasmine-reporters@~2.5.0: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jasmine-reporters/-/jasmine-reporters-2.5.2.tgz#b5dfa1d9c40b8020c5225e0e1e2b9953d66a4d69" + integrity sha512-qdewRUuFOSiWhiyWZX8Yx3YNQ9JG51ntBEO4ekLQRpktxFTwUHy24a86zD/Oi2BRTKksEdfWQZcQFqzjqIkPig== + dependencies: + "@xmldom/xmldom" "^0.8.5" + mkdirp "^1.0.4" + +jasmine@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine/-/jasmine-4.5.0.tgz#8d3c0d0a33a61e4d05c9f9747ee5dfaf6f7b5d3d" + integrity sha512-9olGRvNZyADIwYL9XBNBst5BTU/YaePzuddK+YRslc7rI9MdTIE4r3xaBKbv2GEmzYYUfMOdTR8/i6JfLZaxSQ== + dependencies: + glob "^7.1.6" + jasmine-core "^4.5.0" + js2xmlparser@^4.0.2: version "4.0.2" resolved 
"https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -539,7 +718,14 @@ linkify-it@^3.0.1: dependencies: uc.micro "^1.0.1" -lodash@^4.17.14, lodash@^4.17.15: +locate-path@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-6.0.0.tgz#55321eb309febbc59c4801d931a72452a681d286" + integrity sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw== + dependencies: + p-locate "^5.0.0" + +lodash@^4.17.15, lodash@^4.17.21: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -563,6 +749,13 @@ magic-string@^0.26.4: dependencies: sourcemap-codec "^1.4.8" +make-dir@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.1.0.tgz#415e967046b3a7f1d185277d84aa58203726a13f" + integrity sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw== + dependencies: + semver "^6.0.0" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -580,16 +773,16 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.2" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" - integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== + version "4.2.3" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.3.tgz#bd76a5eb510ff1d8421bc6c3b2f0b93488c15bea" + integrity sha512-slWRdJkbTZ+PjkyJnE30Uid64eHwbwa1Q25INCAYfZlK4o6ylagBy/Le9eWntqJFoFT93ikUKMv47GZ4gTwHkw== mdurl@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== -minimatch@^3.1.1: +minimatch@^3.0.4, minimatch@^3.1.1: version "3.1.2" resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -597,9 +790,9 @@ minimatch@^3.1.1: brace-expansion "^1.1.7" minimatch@^5.0.1: - version "5.1.0" - resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" - integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + version "5.1.1" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.1.tgz#6c9dffcf9927ff2a31e74b5af11adf8b9604b022" + integrity sha512-362NP+zlprccbEt/SkxKfRMHnNY85V74mVnpUpNyr3F35covl09Kec7/sEFLt3RA4oXmewtoaanoIf67SE5Y5g== dependencies: brace-expansion "^2.0.1" @@ -632,11 +825,35 @@ optionator@^0.8.1: type-check "~0.3.2" word-wrap "~1.2.3" +p-limit@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-3.1.0.tgz#e1daccbe78d0d1388ca18c64fea38e3e57e3706b" + integrity sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ== + dependencies: + yocto-queue "^0.1.0" + +p-locate@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-5.0.0.tgz#83c8315c6785005e3bd021839411c9e110e6d834" + integrity 
sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw== + dependencies: + p-limit "^3.0.2" + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + path-is-absolute@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + path-parse@^1.0.7: version "1.0.7" resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" @@ -686,12 +903,17 @@ protobufjs@^7.1.2: "@types/node" ">=13.7.0" long "^5.0.0" +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + requizzle@^0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" - integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + version "0.2.4" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.4.tgz#319eb658b28c370f0c20f968fa8ceab98c13d27c" + integrity sha512-JRrFk1D4OQ4SqovXOgdav+K8EAhSB/LJZqCz8tbX0KObcdeM15Ss59ozWMBWmmINMagCwmqn4ZNryUGpBsl6Jw== dependencies: - lodash "^4.17.14" + lodash "^4.17.21" resolve@^1.22.1: version "1.22.1" @@ -721,6 +943,11 @@ semver@5.6.0: resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== +semver@^6.0.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + semver@^7.1.2: version "7.3.8" resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" @@ -728,6 +955,23 @@ semver@^7.1.2: dependencies: lru-cache "^6.0.0" +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.2: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity 
sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + source-map-support@0.5.9: version "0.5.9" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" @@ -749,11 +993,32 @@ source-map@^0.6.0, source-map@~0.6.1: resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@^0.7.3: + version "0.7.4" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.4.tgz#a9bbe705c9d8846f4e08ff6765acf0f1b0898656" + integrity sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -777,15 +1042,24 @@ taffydb@2.6.2: integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== terser@^5.15.1: - version "5.15.1" - resolved "https://registry.yarnpkg.com/terser/-/terser-5.15.1.tgz#8561af6e0fd6d839669c73b92bdd5777d870ed6c" - integrity sha512-K1faMUvpm/FBxjBXud0LWVAGxmvoPbZbfTCYbSgaaYQaIXI3/TdI7a7ZGA73Zrou6Q8Zmz3oeUTsp/dj+ag2Xw== + version "5.16.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.16.1.tgz#5af3bc3d0f24241c7fb2024199d5c461a1075880" + integrity sha512-xvQfyfA1ayT0qdK47zskQgRZeWLoOQ8JQ6mIgRGVNwZKdQMU+5FkCBjmv4QjcrTzyZquRw2FVtlJSRUmMKQslw== dependencies: "@jridgewell/source-map" "^0.3.2" acorn "^8.5.0" commander "^2.20.0" source-map-support "~0.5.20" +test-exclude@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-6.0.0.tgz#04a8698661d805ea6fa293b6cb9e63ac044ef15e" + integrity sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w== + dependencies: + "@istanbuljs/schema" "^0.1.2" + glob "^7.1.4" + minimatch "^3.0.4" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" @@ -820,9 +1094,9 @@ type-check@~0.3.2: prelude-ls "~1.1.2" typescript@^4.8.4: - version "4.8.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" - integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + version "4.9.3" + 
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.3.tgz#3aea307c1746b8c384435d8ac36b8a2e580d85db" + integrity sha512-CIfGzTelbKNEnLpLdGFgdyKhG23CKdKgQPOBc+OUNrkJ2vr+KSzsSV5kq5iWhEQbok+quxgGzrAtGWCyU7tHnA== uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" @@ -839,11 +1113,36 @@ underscore@~1.13.2: resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== +v8-to-istanbul@^7.1.0: + version "7.1.2" + resolved "https://registry.yarnpkg.com/v8-to-istanbul/-/v8-to-istanbul-7.1.2.tgz#30898d1a7fa0c84d225a2c1434fb958f290883c1" + integrity sha512-TxNb7YEUwkLXCQYeudi6lgQ/SZrzNO4kMdlqVxaZPUIUjCv6iSSypUQX70kNBSERpQ8fk48+d61FXk+tgqcWow== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.1" + convert-source-map "^1.6.0" + source-map "^0.7.3" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + word-wrap@~1.2.3: version "1.2.3" resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + wrappy@1: version "1.0.2" resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" @@ -854,7 +1153,35 @@ xmlcreate@^2.0.4: resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yargs-parser@^20.0.0, yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.0.0: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + +yocto-queue@^0.1.0: + version "0.1.0" + resolved 
"https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" + integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==