Merge branch 'master' into ios-task
This commit is contained in:
commit
ee230520da
25
.github/ISSUE_TEMPLATE/11-tasks-issue.md
vendored
Normal file
25
.github/ISSUE_TEMPLATE/11-tasks-issue.md
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
---
|
||||
name: "Tasks Issue"
|
||||
about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms.
|
||||
labels: type:support
|
||||
|
||||
---
|
||||
<em>Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue.<em>
|
||||
|
||||
**System information** (Please provide as much relevant information as possible)
|
||||
- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
|
||||
- MediaPipe Tasks SDK version:
|
||||
- Task name (e.g. Object detection, Gesture recognition etc.):
|
||||
- Programming Language and version ( e.g. C++, Python, Java):
|
||||
|
||||
**Describe the expected behavior:**
|
||||
|
||||
**Standalone code you may have used to try to get what you need :**
|
||||
|
||||
If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
|
||||
|
||||
**Other info / Complete Logs :**
|
||||
Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
traceback. Large logs and files should be attached:
|
25
.github/ISSUE_TEMPLATE/12-model-maker-issue.md
vendored
Normal file
25
.github/ISSUE_TEMPLATE/12-model-maker-issue.md
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
---
|
||||
name: "Model Maker Issue"
|
||||
about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions.
|
||||
labels: type:support
|
||||
|
||||
---
|
||||
<em>Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue.<em>
|
||||
|
||||
**System information** (Please provide as much relevant information as possible)
|
||||
- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Python version (e.g. 3.8):
|
||||
- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/):
|
||||
- Task name (e.g. Image classification, Gesture recognition etc.):
|
||||
|
||||
**Describe the expected behavior:**
|
||||
|
||||
**Standalone code you may have used to try to get what you need :**
|
||||
|
||||
If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
|
||||
|
||||
**Other info / Complete Logs :**
|
||||
Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
traceback. Large logs and files should be attached:
|
|
@ -1,6 +1,6 @@
|
|||
---
|
||||
name: "Solution Issue"
|
||||
about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
|
||||
name: "Solution (legacy) Issue"
|
||||
about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
|
||||
labels: type:support
|
||||
|
||||
---
|
19
.github/ISSUE_TEMPLATE/14-studio-issue.md
vendored
Normal file
19
.github/ISSUE_TEMPLATE/14-studio-issue.md
vendored
Normal file
|
@ -0,0 +1,19 @@
|
|||
---
|
||||
name: "Studio Issue"
|
||||
about: Use this template for assistance with the MediaPipe Studio application.
|
||||
labels: type:support
|
||||
|
||||
---
|
||||
<em>Please make sure that this is a MediaPipe Studio issue.<em>
|
||||
|
||||
**System information** (Please provide as much relevant information as possible)
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
|
||||
- Browser and Version
|
||||
- Any microphone or camera hardware
|
||||
- URL that shows the problem
|
||||
|
||||
**Describe the expected behavior:**
|
||||
|
||||
**Other info / Complete Logs :**
|
||||
Include any js console logs that would be helpful to diagnose the problem.
|
||||
Large logs and files should be attached:
|
56
WORKSPACE
56
WORKSPACE
|
@ -320,12 +320,30 @@ http_archive(
|
|||
],
|
||||
)
|
||||
|
||||
# iOS basic build deps.
|
||||
# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee
|
||||
# that the target @zlib//:mini_zlib is available
|
||||
http_archive(
|
||||
name = "zlib",
|
||||
build_file = "//third_party:zlib.BUILD",
|
||||
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
|
||||
strip_prefix = "zlib-1.2.11",
|
||||
urls = [
|
||||
"http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
|
||||
"http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
|
||||
],
|
||||
patches = [
|
||||
"@//third_party:zlib.diff",
|
||||
],
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
)
|
||||
|
||||
# iOS basic build deps.
|
||||
http_archive(
|
||||
name = "build_bazel_rules_apple",
|
||||
sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2",
|
||||
url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz",
|
||||
sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911",
|
||||
url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz",
|
||||
patches = [
|
||||
# Bypass checking ios unit test runner when building MP ios applications.
|
||||
"@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff"
|
||||
|
@ -339,29 +357,24 @@ load(
|
|||
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||
"apple_rules_dependencies",
|
||||
)
|
||||
|
||||
apple_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_swift//swift:repositories.bzl",
|
||||
"swift_rules_dependencies",
|
||||
)
|
||||
|
||||
swift_rules_dependencies()
|
||||
|
||||
http_archive(
|
||||
name = "build_bazel_apple_support",
|
||||
sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35",
|
||||
urls = [
|
||||
"https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz"
|
||||
],
|
||||
load(
|
||||
"@build_bazel_rules_swift//swift:extras.bzl",
|
||||
"swift_rules_extra_dependencies",
|
||||
)
|
||||
swift_rules_extra_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||
"apple_support_dependencies",
|
||||
)
|
||||
|
||||
apple_support_dependencies()
|
||||
|
||||
# More iOS deps.
|
||||
|
@ -442,25 +455,6 @@ http_archive(
|
|||
],
|
||||
)
|
||||
|
||||
# Load Zlib before initializing TensorFlow to guarantee that the target
|
||||
# @zlib//:mini_zlib is available
|
||||
http_archive(
|
||||
name = "zlib",
|
||||
build_file = "//third_party:zlib.BUILD",
|
||||
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
|
||||
strip_prefix = "zlib-1.2.11",
|
||||
urls = [
|
||||
"http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
|
||||
"http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
|
||||
],
|
||||
patches = [
|
||||
"@//third_party:zlib.diff",
|
||||
],
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
)
|
||||
|
||||
# TensorFlow repo should always go after the other external dependencies.
|
||||
# TF on 2022-08-10.
|
||||
_TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b"
|
||||
|
|
81
docs/build_model_maker_api_docs.py
Normal file
81
docs/build_model_maker_api_docs.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""MediaPipe Model Maker reference docs generation script.
|
||||
|
||||
This script generates API reference docs for the `mediapipe` PIP package.
|
||||
|
||||
$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker
|
||||
$> python build_model_maker_api_docs.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow_docs.api_generator import generate_lib
|
||||
|
||||
try:
|
||||
# mediapipe has not been set up to work with bazel yet, so catch & report.
|
||||
import mediapipe_model_maker # pytype: disable=import-error
|
||||
except ImportError as e:
|
||||
raise ImportError('Please `pip install mediapipe-model-maker`.') from e
|
||||
|
||||
|
||||
PROJECT_SHORT_NAME = 'mediapipe_model_maker'
|
||||
PROJECT_FULL_NAME = 'MediaPipe Model Maker'
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir',
|
||||
default='/tmp/generated_docs',
|
||||
help='Where to write the resulting docs.')
|
||||
|
||||
_URL_PREFIX = flags.DEFINE_string(
|
||||
'code_url_prefix',
|
||||
'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker',
|
||||
'The url prefix for links to code.')
|
||||
|
||||
_SEARCH_HINTS = flags.DEFINE_bool(
|
||||
'search_hints', True,
|
||||
'Include metadata search hints in the generated files')
|
||||
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
|
||||
'Path prefix in the _toc.yaml')
|
||||
|
||||
|
||||
def gen_api_docs():
|
||||
"""Generates API docs for the mediapipe-model-maker package."""
|
||||
|
||||
doc_generator = generate_lib.DocGenerator(
|
||||
root_title=PROJECT_FULL_NAME,
|
||||
py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)],
|
||||
base_dir=os.path.dirname(mediapipe_model_maker.__file__),
|
||||
code_url_prefix=_URL_PREFIX.value,
|
||||
search_hints=_SEARCH_HINTS.value,
|
||||
site_path=_SITE_PATH.value,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
doc_generator.build(_OUTPUT_DIR.value)
|
||||
|
||||
print('Docs output to:', _OUTPUT_DIR.value)
|
||||
|
||||
|
||||
def main(_):
|
||||
gen_api_docs()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -26,7 +26,6 @@ from absl import app
|
|||
from absl import flags
|
||||
|
||||
from tensorflow_docs.api_generator import generate_lib
|
||||
from tensorflow_docs.api_generator import public_api
|
||||
|
||||
try:
|
||||
# mediapipe has not been set up to work with bazel yet, so catch & report.
|
||||
|
@ -45,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string(
|
|||
|
||||
_URL_PREFIX = flags.DEFINE_string(
|
||||
'code_url_prefix',
|
||||
'https://github.com/google/mediapipe/tree/master/mediapipe',
|
||||
'https://github.com/google/mediapipe/blob/master/mediapipe',
|
||||
'The url prefix for links to code.')
|
||||
|
||||
_SEARCH_HINTS = flags.DEFINE_bool(
|
||||
'search_hints', True,
|
||||
'Include metadata search hints in the generated files')
|
||||
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python',
|
||||
'Path prefix in the _toc.yaml')
|
||||
|
||||
|
||||
|
@ -68,10 +67,7 @@ def gen_api_docs():
|
|||
code_url_prefix=_URL_PREFIX.value,
|
||||
search_hints=_SEARCH_HINTS.value,
|
||||
site_path=_SITE_PATH.value,
|
||||
# This callback ensures that docs are only generated for objects that
|
||||
# are explicitly imported in your __init__.py files. There are other
|
||||
# options but this is a good starting point.
|
||||
callbacks=[public_api.explicit_package_contents_filter],
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
doc_generator.build(_OUTPUT_DIR.value)
|
||||
|
|
|
@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "concatenate_vector_calculator_proto",
|
||||
srcs = ["concatenate_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -32,7 +31,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "dequantize_byte_array_calculator_proto",
|
||||
srcs = ["dequantize_byte_array_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -42,7 +40,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "packet_cloner_calculator_proto",
|
||||
srcs = ["packet_cloner_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -52,7 +49,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "packet_resampler_calculator_proto",
|
||||
srcs = ["packet_resampler_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -62,7 +58,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "packet_thinner_calculator_proto",
|
||||
srcs = ["packet_thinner_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -72,7 +67,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "split_vector_calculator_proto",
|
||||
srcs = ["split_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -82,7 +76,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "quantize_float_vector_calculator_proto",
|
||||
srcs = ["quantize_float_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -92,7 +85,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "sequence_shift_calculator_proto",
|
||||
srcs = ["sequence_shift_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -102,7 +94,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "gate_calculator_proto",
|
||||
srcs = ["gate_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -112,7 +103,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "constant_side_packet_calculator_proto",
|
||||
srcs = ["constant_side_packet_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -124,7 +114,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "clip_vector_size_calculator_proto",
|
||||
srcs = ["clip_vector_size_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -134,7 +123,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "flow_limiter_calculator_proto",
|
||||
srcs = ["flow_limiter_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -144,7 +132,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "graph_profile_calculator_proto",
|
||||
srcs = ["graph_profile_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -154,7 +141,6 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "get_vector_item_calculator_proto",
|
||||
srcs = ["get_vector_item_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -164,7 +150,6 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "add_header_calculator",
|
||||
srcs = ["add_header_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -193,7 +178,6 @@ cc_library(
|
|||
name = "begin_loop_calculator",
|
||||
srcs = ["begin_loop_calculator.cc"],
|
||||
hdrs = ["begin_loop_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
|
@ -216,7 +200,6 @@ cc_library(
|
|||
name = "end_loop_calculator",
|
||||
srcs = ["end_loop_calculator.cc"],
|
||||
hdrs = ["end_loop_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
|
@ -258,7 +241,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "concatenate_vector_calculator_hdr",
|
||||
hdrs = ["concatenate_vector_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -284,7 +266,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -311,7 +292,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "concatenate_detection_vector_calculator",
|
||||
srcs = ["concatenate_detection_vector_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -323,7 +303,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "concatenate_proto_list_calculator",
|
||||
srcs = ["concatenate_proto_list_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -372,7 +351,6 @@ cc_library(
|
|||
name = "clip_vector_size_calculator",
|
||||
srcs = ["clip_vector_size_calculator.cc"],
|
||||
hdrs = ["clip_vector_size_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":clip_vector_size_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -388,7 +366,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "clip_detection_vector_size_calculator",
|
||||
srcs = ["clip_detection_vector_size_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":clip_vector_size_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -415,9 +392,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "counting_source_calculator",
|
||||
srcs = ["counting_source_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -430,9 +404,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "make_pair_calculator",
|
||||
srcs = ["make_pair_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -461,9 +432,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "matrix_multiply_calculator",
|
||||
srcs = ["matrix_multiply_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -477,9 +445,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "matrix_subtract_calculator",
|
||||
srcs = ["matrix_subtract_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -493,9 +458,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "mux_calculator",
|
||||
srcs = ["mux_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -508,9 +470,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "non_zero_calculator",
|
||||
srcs = ["non_zero_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -556,9 +515,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "packet_cloner_calculator",
|
||||
srcs = ["packet_cloner_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":packet_cloner_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -587,7 +543,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "packet_inner_join_calculator",
|
||||
srcs = ["packet_inner_join_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -611,7 +566,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "packet_thinner_calculator",
|
||||
srcs = ["packet_thinner_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
|
@ -643,9 +597,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "pass_through_calculator",
|
||||
srcs = ["pass_through_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -656,9 +607,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "round_robin_demux_calculator",
|
||||
srcs = ["round_robin_demux_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -670,9 +618,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "immediate_mux_calculator",
|
||||
srcs = ["immediate_mux_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -684,7 +629,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "packet_presence_calculator",
|
||||
srcs = ["packet_presence_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
|
@ -713,7 +657,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "previous_loopback_calculator",
|
||||
srcs = ["previous_loopback_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
|
@ -729,7 +672,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "flow_limiter_calculator",
|
||||
srcs = ["flow_limiter_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":flow_limiter_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -746,7 +688,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "string_to_int_calculator",
|
||||
srcs = ["string_to_int_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
@ -759,7 +700,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "default_side_packet_calculator",
|
||||
srcs = ["default_side_packet_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -771,7 +711,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "side_packet_to_stream_calculator",
|
||||
srcs = ["side_packet_to_stream_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
@ -822,9 +761,6 @@ cc_library(
|
|||
name = "packet_resampler_calculator",
|
||||
srcs = ["packet_resampler_calculator.cc"],
|
||||
hdrs = ["packet_resampler_calculator.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -884,7 +820,6 @@ cc_test(
|
|||
cc_test(
|
||||
name = "matrix_multiply_calculator_test",
|
||||
srcs = ["matrix_multiply_calculator_test.cc"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":matrix_multiply_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -900,7 +835,6 @@ cc_test(
|
|||
cc_test(
|
||||
name = "matrix_subtract_calculator_test",
|
||||
srcs = ["matrix_subtract_calculator_test.cc"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":matrix_subtract_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -950,7 +884,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
@ -996,7 +929,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "split_proto_list_calculator",
|
||||
srcs = ["split_proto_list_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1028,7 +960,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "dequantize_byte_array_calculator",
|
||||
srcs = ["dequantize_byte_array_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":dequantize_byte_array_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
|
@ -1054,7 +985,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "quantize_float_vector_calculator",
|
||||
srcs = ["quantize_float_vector_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":quantize_float_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
|
@ -1080,7 +1010,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "sequence_shift_calculator",
|
||||
srcs = ["sequence_shift_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":sequence_shift_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1105,7 +1034,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "gate_calculator",
|
||||
srcs = ["gate_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gate_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1131,7 +1059,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "matrix_to_vector_calculator",
|
||||
srcs = ["matrix_to_vector_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -1167,7 +1094,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "merge_calculator",
|
||||
srcs = ["merge_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -1193,7 +1119,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "stream_to_side_packet_calculator",
|
||||
srcs = ["stream_to_side_packet_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
|
@ -1219,7 +1144,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "constant_side_packet_calculator",
|
||||
srcs = ["constant_side_packet_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":constant_side_packet_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1249,7 +1173,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "graph_profile_calculator",
|
||||
srcs = ["graph_profile_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_profile_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1291,7 +1214,6 @@ cc_library(
|
|||
name = "get_vector_item_calculator",
|
||||
srcs = ["get_vector_item_calculator.cc"],
|
||||
hdrs = ["get_vector_item_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":get_vector_item_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -1325,7 +1247,6 @@ cc_library(
|
|||
name = "vector_indices_calculator",
|
||||
srcs = ["vector_indices_calculator.cc"],
|
||||
hdrs = ["vector_indices_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -1351,7 +1272,6 @@ cc_library(
|
|||
name = "vector_size_calculator",
|
||||
srcs = ["vector_size_calculator.cc"],
|
||||
hdrs = ["vector_size_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
|
@ -1365,9 +1285,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "packet_sequencer_calculator",
|
||||
srcs = ["packet_sequencer_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:contract",
|
||||
|
@ -1402,11 +1319,11 @@ cc_library(
|
|||
name = "merge_to_vector_calculator",
|
||||
srcs = ["merge_to_vector_calculator.cc"],
|
||||
hdrs = ["merge_to_vector_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
|
@ -1416,7 +1333,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "bypass_calculator_proto",
|
||||
srcs = ["bypass_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -1426,7 +1342,6 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "bypass_calculator",
|
||||
srcs = ["bypass_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bypass_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
|
|
@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node {
|
|||
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
|
||||
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
|
||||
return absl::OkStatus();
|
||||
|
@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
RET_CHECK(idx >= 0 && idx < items.size());
|
||||
kOut(cc).Send(items[idx]);
|
||||
RET_CHECK(idx >= 0);
|
||||
RET_CHECK(options.output_empty_on_oob() || idx < items.size());
|
||||
|
||||
if (idx < items.size()) {
|
||||
kOut(cc).Send(items[idx]);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions {
|
|||
// Index of vector item to get. INDEX input stream can be used instead, or to
|
||||
// override.
|
||||
optional int32 item_index = 1;
|
||||
|
||||
// Set to true to output an empty packet when the index is out of bounds.
|
||||
optional bool output_empty_on_oob = 2;
|
||||
}
|
||||
|
|
|
@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() {
|
|||
)");
|
||||
}
|
||||
|
||||
CalculatorRunner MakeRunnerWithOptions(int set_index) {
|
||||
return CalculatorRunner(absl::StrFormat(R"(
|
||||
CalculatorRunner MakeRunnerWithOptions(int set_index,
|
||||
bool output_empty_on_oob = false) {
|
||||
return CalculatorRunner(
|
||||
absl::StrFormat(R"(
|
||||
calculator: "TestGetIntVectorItemCalculator"
|
||||
input_stream: "VECTOR:vector_stream"
|
||||
output_stream: "ITEM:item_stream"
|
||||
options {
|
||||
[mediapipe.GetVectorItemCalculatorOptions.ext] {
|
||||
item_index: %d
|
||||
output_empty_on_oob: %s
|
||||
}
|
||||
}
|
||||
)",
|
||||
set_index));
|
||||
set_index, output_empty_on_oob ? "true" : "false"));
|
||||
}
|
||||
|
||||
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
|
||||
|
@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) {
|
|||
|
||||
absl::Status status = runner.Run();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.message(),
|
||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
|
||||
}
|
||||
|
||||
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
|
||||
|
@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
|
|||
absl::Status status = runner.Run();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.message(),
|
||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
||||
testing::HasSubstr(
|
||||
"options.output_empty_on_oob() || idx < items.size()"));
|
||||
}
|
||||
|
||||
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
|
||||
|
@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
|
|||
|
||||
absl::Status status = runner.Run();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.message(),
|
||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
|
||||
}
|
||||
|
||||
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
|
||||
|
@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
|
|||
absl::Status status = runner.Run();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.message(),
|
||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
||||
testing::HasSubstr(
|
||||
"options.output_empty_on_oob() || idx < items.size()"));
|
||||
}
|
||||
|
||||
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) {
|
||||
const int try_index = 3;
|
||||
CalculatorRunner runner = MakeRunnerWithOptions(try_index, true);
|
||||
const std::vector<int> inputs = {1, 2, 3};
|
||||
|
||||
AddInputVector(runner, inputs, 1);
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
|
||||
EXPECT_THAT(outputs, testing::ElementsAre());
|
||||
}
|
||||
|
||||
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/calculators/core/merge_to_vector_calculator.h"
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -23,5 +24,13 @@ namespace api2 {
|
|||
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
|
||||
|
||||
typedef MergeToVectorCalculator<mediapipe::GpuBuffer>
|
||||
MergeGpuBuffersToVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
|
||||
|
||||
typedef MergeToVectorCalculator<mediapipe::Detection>
|
||||
MergeDetectionsToVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(::mediapipe::CalculatorContext* cc) {
|
||||
cc->SetOffset(::mediapipe::TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) {
|
||||
const int input_num = kIn(cc).Count();
|
||||
std::vector<T> output_vector(input_num);
|
||||
std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(),
|
||||
[](const auto& elem) -> T { return elem.Get(); });
|
||||
std::vector<T> output_vector;
|
||||
for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
|
||||
const auto& elem = *it;
|
||||
if (!elem.IsEmpty()) {
|
||||
output_vector.push_back(elem.Get());
|
||||
}
|
||||
}
|
||||
kOut(cc).Send(output_vector);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Rect;
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
|
|
@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) {
|
|||
auto cc = absl::make_unique<CalculatorContext>(
|
||||
calculator_state.get(), inputTags, tool::CreateTagMap({}).value());
|
||||
auto& inputs = cc->Inputs();
|
||||
mediapipe::Rect rect = ParseTextProtoOrDie<mediapipe::Rect>(
|
||||
Rect rect = ParseTextProtoOrDie<Rect>(
|
||||
R"pb(
|
||||
width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5
|
||||
)pb");
|
||||
inputs.Tag(kRectTag).Value() = MakePacket<mediapipe::Rect>(rect);
|
||||
inputs.Tag(kRectTag).Value() = MakePacket<Rect>(rect);
|
||||
RectSpec expectRect = {
|
||||
.width = 1,
|
||||
.height = 1,
|
||||
|
|
|
@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"])
|
|||
proto_library(
|
||||
name = "callback_packet_calculator_proto",
|
||||
srcs = ["callback_packet_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
|
@ -29,14 +29,14 @@ mediapipe_cc_proto_library(
|
|||
name = "callback_packet_calculator_cc_proto",
|
||||
srcs = ["callback_packet_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [":callback_packet_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "callback_packet_calculator",
|
||||
srcs = ["callback_packet_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
":callback_packet_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
|
|
|
@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
exports_files(
|
||||
glob(["testdata/image_to_tensor/*"]),
|
||||
|
@ -44,9 +44,6 @@ selects.config_setting_group(
|
|||
mediapipe_proto_library(
|
||||
name = "audio_to_tensor_calculator_proto",
|
||||
srcs = ["audio_to_tensor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -64,9 +61,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":audio_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -113,9 +107,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_audio_calculator_proto",
|
||||
srcs = ["tensors_to_audio_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -125,9 +116,6 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "tensors_to_audio_calculator",
|
||||
srcs = ["tensors_to_audio_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":tensors_to_audio_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -164,9 +152,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "feedback_tensors_calculator_proto",
|
||||
srcs = ["feedback_tensors_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -184,9 +169,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":feedback_tensors_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -216,9 +198,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "bert_preprocessor_calculator_proto",
|
||||
srcs = ["bert_preprocessor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -228,9 +207,6 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "bert_preprocessor_calculator",
|
||||
srcs = ["bert_preprocessor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":bert_preprocessor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -274,9 +250,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "regex_preprocessor_calculator_proto",
|
||||
srcs = ["regex_preprocessor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -286,9 +259,6 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "regex_preprocessor_calculator",
|
||||
srcs = ["regex_preprocessor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":regex_preprocessor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -330,9 +300,6 @@ cc_test(
|
|||
cc_library(
|
||||
name = "text_to_tensor_calculator",
|
||||
srcs = ["text_to_tensor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -405,7 +372,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "inference_calculator_proto",
|
||||
srcs = ["inference_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -432,7 +398,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
":inference_calculator_options_lib",
|
||||
|
@ -457,7 +422,6 @@ cc_library(
|
|||
name = "inference_calculator_gl",
|
||||
srcs = ["inference_calculator_gl.cc"],
|
||||
tags = ["nomac"], # config problem with cpuinfo via TF
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
":inference_calculator_interface",
|
||||
|
@ -475,7 +439,6 @@ cc_library(
|
|||
name = "inference_calculator_gl_advanced",
|
||||
srcs = ["inference_calculator_gl_advanced.cc"],
|
||||
tags = ["nomac"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -506,7 +469,6 @@ cc_library(
|
|||
"-framework MetalKit",
|
||||
],
|
||||
tags = ["ios"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"inference_calculator_interface",
|
||||
"//mediapipe/gpu:MPPMetalHelper",
|
||||
|
@ -535,7 +497,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
|
@ -555,7 +516,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_runner",
|
||||
"//mediapipe/framework:mediapipe_profiling",
|
||||
|
@ -585,7 +545,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
":inference_calculator_utils",
|
||||
|
@ -632,7 +591,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
":inference_calculator_utils",
|
||||
|
@ -648,7 +606,6 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "inference_calculator_gl_if_compute_shader_available",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = selects.with_or({
|
||||
":compute_shader_unavailable": [],
|
||||
"//conditions:default": [
|
||||
|
@ -664,7 +621,6 @@ cc_library(
|
|||
# inference_calculator_interface.
|
||||
cc_library(
|
||||
name = "inference_calculator",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
":inference_calculator_cpu",
|
||||
|
@ -678,7 +634,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "tensor_converter_calculator_proto",
|
||||
srcs = ["tensor_converter_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -703,7 +658,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensor_converter_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -722,6 +676,7 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "tensor_converter_calculator_gpu_deps",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = select({
|
||||
"//mediapipe:android": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
|
@ -766,7 +721,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_detections_calculator_proto",
|
||||
srcs = ["tensors_to_detections_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -791,7 +745,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensors_to_detections_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
@ -814,6 +767,7 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "tensors_to_detections_calculator_gpu_deps",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = select({
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:MPPMetalUtil",
|
||||
|
@ -829,7 +783,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_landmarks_calculator_proto",
|
||||
srcs = ["tensors_to_landmarks_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -846,7 +799,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensors_to_landmarks_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -861,7 +813,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "landmarks_to_tensor_calculator_proto",
|
||||
srcs = ["landmarks_to_tensor_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -879,7 +830,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":landmarks_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -912,7 +862,6 @@ cc_test(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_floats_calculator_proto",
|
||||
srcs = ["tensors_to_floats_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -929,7 +878,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensors_to_floats_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -967,7 +915,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensors_to_classification_calculator_cc_proto",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
|
@ -998,7 +945,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_classification_calculator_proto",
|
||||
srcs = ["tensors_to_classification_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -1036,7 +982,6 @@ cc_library(
|
|||
"//conditions:default": [],
|
||||
}),
|
||||
features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_calculator_cc_proto",
|
||||
":image_to_tensor_converter",
|
||||
|
@ -1065,6 +1010,7 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "image_to_tensor_calculator_gpu_deps",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = selects.with_or({
|
||||
"//mediapipe:android": [
|
||||
":image_to_tensor_converter_gl_buffer",
|
||||
|
@ -1088,7 +1034,6 @@ cc_library(
|
|||
mediapipe_proto_library(
|
||||
name = "image_to_tensor_calculator_proto",
|
||||
srcs = ["image_to_tensor_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -1151,7 +1096,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_utils",
|
||||
"//mediapipe/framework/formats:image",
|
||||
|
@ -1171,7 +1115,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_converter",
|
||||
":image_to_tensor_utils",
|
||||
|
@ -1191,6 +1134,7 @@ cc_library(
|
|||
name = "image_to_tensor_converter_gl_buffer",
|
||||
srcs = ["image_to_tensor_converter_gl_buffer.cc"],
|
||||
hdrs = ["image_to_tensor_converter_gl_buffer.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = ["//mediapipe/framework:port"] + selects.with_or({
|
||||
"//mediapipe:apple": [],
|
||||
"//conditions:default": [
|
||||
|
@ -1224,6 +1168,7 @@ cc_library(
|
|||
name = "image_to_tensor_converter_gl_texture",
|
||||
srcs = ["image_to_tensor_converter_gl_texture.cc"],
|
||||
hdrs = ["image_to_tensor_converter_gl_texture.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = ["//mediapipe/framework:port"] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [
|
||||
|
@ -1248,6 +1193,7 @@ cc_library(
|
|||
name = "image_to_tensor_converter_gl_utils",
|
||||
srcs = ["image_to_tensor_converter_gl_utils.cc"],
|
||||
hdrs = ["image_to_tensor_converter_gl_utils.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = ["//mediapipe/framework:port"] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [
|
||||
|
@ -1277,6 +1223,7 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:private"],
|
||||
deps = ["//mediapipe/framework:port"] + select({
|
||||
"//mediapipe:apple": [
|
||||
":image_to_tensor_converter",
|
||||
|
@ -1308,7 +1255,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_calculator_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
|
@ -1351,7 +1297,6 @@ selects.config_setting_group(
|
|||
mediapipe_proto_library(
|
||||
name = "tensors_to_segmentation_calculator_proto",
|
||||
srcs = ["tensors_to_segmentation_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -1369,7 +1314,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tensors_to_segmentation_calculator_cc_proto",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -1427,7 +1371,6 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
|
|
@ -17,6 +17,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.calculator.proto";
|
||||
option java_outer_classname = "InferenceCalculatorProto";
|
||||
|
|
|
@ -456,6 +456,23 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detections_deduplicate_calculator",
|
||||
srcs = [
|
||||
"detections_deduplicate_calculator.cc",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rect_transformation_calculator",
|
||||
srcs = ["rect_transformation_calculator.cc"],
|
||||
|
|
114
mediapipe/calculators/util/detections_deduplicate_calculator.cc
Normal file
114
mediapipe/calculators/util/detections_deduplicate_calculator.cc
Normal file
|
@ -0,0 +1,114 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
struct BoundingBoxHash {
|
||||
size_t operator()(const LocationData::BoundingBox& bbox) const {
|
||||
return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
|
||||
std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
|
||||
}
|
||||
};
|
||||
|
||||
struct BoundingBoxEq {
|
||||
bool operator()(const LocationData::BoundingBox& lhs,
|
||||
const LocationData::BoundingBox& rhs) const {
|
||||
return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
|
||||
lhs.width() == rhs.width() && lhs.height() == rhs.height();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// This Calculator deduplicates the bunding boxes with exactly the same
|
||||
// coordinates, and folds the labels into a single Detection proto. Note
|
||||
// non-maximum-suppression remove the overlapping bounding boxes within a class,
|
||||
// while the deduplication operation merges bounding boxes from different
|
||||
// classes.
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "DetectionsDeduplicateCalculator"
|
||||
// input_stream: "detections"
|
||||
// output_stream: "deduplicated_detections"
|
||||
// }
|
||||
class DetectionsDeduplicateCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<Detection>> kIn{""};
|
||||
static constexpr Output<std::vector<Detection>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
absl::Status Open(mediapipe::CalculatorContext* cc) {
|
||||
cc->SetOffset(::mediapipe::TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(mediapipe::CalculatorContext* cc) {
|
||||
const std::vector<Detection>& raw_detections = kIn(cc).Get();
|
||||
absl::flat_hash_map<LocationData::BoundingBox, Detection*, BoundingBoxHash,
|
||||
BoundingBoxEq>
|
||||
bbox_to_detections;
|
||||
std::vector<Detection> deduplicated_detections;
|
||||
for (const auto& detection : raw_detections) {
|
||||
if (!detection.has_location_data() ||
|
||||
!detection.location_data().has_bounding_box()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The location data of Detections must be BoundingBox.");
|
||||
}
|
||||
if (bbox_to_detections.contains(
|
||||
detection.location_data().bounding_box())) {
|
||||
// The bbox location already exists. Merge the detection labels into
|
||||
// the existing detection proto.
|
||||
Detection& deduplicated_detection =
|
||||
*bbox_to_detections[detection.location_data().bounding_box()];
|
||||
deduplicated_detection.mutable_score()->MergeFrom(detection.score());
|
||||
deduplicated_detection.mutable_label()->MergeFrom(detection.label());
|
||||
deduplicated_detection.mutable_label_id()->MergeFrom(
|
||||
detection.label_id());
|
||||
deduplicated_detection.mutable_display_name()->MergeFrom(
|
||||
detection.display_name());
|
||||
} else {
|
||||
// The bbox location appears first time. Add the detection to output
|
||||
// detection vector.
|
||||
deduplicated_detections.push_back(detection);
|
||||
bbox_to_detections[detection.location_data().bounding_box()] =
|
||||
&deduplicated_detections.back();
|
||||
}
|
||||
}
|
||||
kOut(cc).Send(std::move(deduplicated_detections));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT";
|
|||
constexpr char kRectsTag[] = "RECTS";
|
||||
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Rect;
|
||||
|
||||
constexpr float kMinFloat = std::numeric_limits<float>::lowest();
|
||||
constexpr float kMaxFloat = std::numeric_limits<float>::max();
|
||||
|
||||
|
|
|
@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
|||
constexpr char kRectTag[] = "RECT";
|
||||
constexpr char kDetectionTag[] = "DETECTION";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Rect;
|
||||
|
||||
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
|
||||
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
|
||||
testing::Value(arg.y_center(), testing::Eq(y_center)) &&
|
||||
|
|
|
@ -24,6 +24,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
|
||||
|
|
|
@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI";
|
|||
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
|
||||
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using mediapipe::OneEuroFilter;
|
||||
using ::mediapipe::Rect;
|
||||
using mediapipe::RelativeVelocityFilter;
|
||||
|
||||
void NormalizedLandmarksToLandmarks(
|
||||
|
|
|
@ -23,6 +23,8 @@ namespace {
|
|||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Projects rectangle from reference coordinate system (defined by reference
|
||||
|
|
|
@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
|
|||
constexpr char kRectsTag[] = "RECTS";
|
||||
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Rect;
|
||||
|
||||
RenderAnnotation::Rectangle* NewRect(
|
||||
const RectToRenderDataCalculatorOptions& options, RenderData* render_data) {
|
||||
auto* annotation = render_data->add_render_annotations();
|
||||
|
|
|
@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT";
|
|||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kRenderScaleTag[] = "RENDER_SCALE";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
} // namespace
|
||||
|
||||
// A calculator to get scale for RenderData primitives.
|
||||
|
|
|
@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT";
|
|||
constexpr char kRectsTag[] = "RECTS";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Rect;
|
||||
|
||||
// Wraps around an angle in radians to within -M_PI and M_PI.
|
||||
inline float NormalizeRadians(float angle) {
|
||||
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
constexpr int kDetectionUpdateTimeOutMS = 5000;
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES";
|
||||
|
|
|
@ -18,7 +18,7 @@ import android.content.ClipDescription;
|
|||
import android.content.Context;
|
||||
import android.net.Uri;
|
||||
import android.os.Bundle;
|
||||
import android.support.v7.widget.AppCompatEditText;
|
||||
import androidx.appcompat.widget.AppCompatEditText;
|
||||
import android.util.AttributeSet;
|
||||
import android.util.Log;
|
||||
import android.view.inputmethod.EditorInfo;
|
||||
|
|
|
@ -29,12 +29,6 @@ objc_library(
|
|||
"Base.lproj/LaunchScreen.storyboard",
|
||||
"Base.lproj/Main.storyboard",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"AVFoundation",
|
||||
"CoreGraphics",
|
||||
"CoreMedia",
|
||||
"UIKit",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
],
|
||||
|
@ -42,6 +36,10 @@ objc_library(
|
|||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//mediapipe/objc:mediapipe_input_sources_ios",
|
||||
"//mediapipe/objc:mediapipe_layer_renderer",
|
||||
"//third_party/apple_frameworks:AVFoundation",
|
||||
"//third_party/apple_frameworks:CoreGraphics",
|
||||
"//third_party/apple_frameworks:CoreMedia",
|
||||
"//third_party/apple_frameworks:UIKit",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -73,13 +73,13 @@ objc_library(
|
|||
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
||||
],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"AVFoundation",
|
||||
"CoreGraphics",
|
||||
"CoreMedia",
|
||||
"UIKit",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:matrix_data_cc_proto",
|
||||
"//third_party/apple_frameworks:AVFoundation",
|
||||
"//third_party/apple_frameworks:CoreGraphics",
|
||||
"//third_party/apple_frameworks:CoreMedia",
|
||||
"//third_party/apple_frameworks:UIKit",
|
||||
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//mediapipe/objc:mediapipe_input_sources_ios",
|
||||
"//mediapipe/objc:mediapipe_layer_renderer",
|
||||
|
@ -87,9 +87,7 @@ objc_library(
|
|||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/framework/formats:matrix_data_cc_proto",
|
||||
"//mediapipe/graphs/face_effect:face_effect_gpu_deps",
|
||||
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -67,12 +67,12 @@ objc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/graphs/face_mesh:mobile_calculators",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -68,12 +68,12 @@ objc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/graphs/hand_tracking:mobile_calculators",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -68,12 +68,12 @@ objc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -67,12 +67,12 @@ objc_library(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -21,6 +21,7 @@ licenses(["notice"])
|
|||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
# The MediaPipe internal package group. No mediapipe users should be added to this group.
|
||||
package_group(
|
||||
name = "mediapipe_internal",
|
||||
packages = [
|
||||
|
@ -56,12 +57,12 @@ mediapipe_proto_library(
|
|||
srcs = ["calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:mediapipe_options_proto",
|
||||
"//mediapipe/framework:packet_factory_proto",
|
||||
"//mediapipe/framework:packet_generator_proto",
|
||||
"//mediapipe/framework:status_handler_proto",
|
||||
"//mediapipe/framework:stream_handler_proto",
|
||||
":calculator_options_proto",
|
||||
":mediapipe_options_proto",
|
||||
":packet_factory_proto",
|
||||
":packet_generator_proto",
|
||||
":status_handler_proto",
|
||||
":stream_handler_proto",
|
||||
"@com_google_protobuf//:any_proto",
|
||||
],
|
||||
)
|
||||
|
@ -78,8 +79,8 @@ mediapipe_proto_library(
|
|||
srcs = ["calculator_contract_test.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
":calculator_options_proto",
|
||||
":calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -88,8 +89,8 @@ mediapipe_proto_library(
|
|||
srcs = ["calculator_profile.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
":calculator_options_proto",
|
||||
":calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -125,14 +126,14 @@ mediapipe_proto_library(
|
|||
name = "status_handler_proto",
|
||||
srcs = ["status_handler.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = ["//mediapipe/framework:mediapipe_options_proto"],
|
||||
deps = [":mediapipe_options_proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "stream_handler_proto",
|
||||
srcs = ["stream_handler.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = ["//mediapipe/framework:mediapipe_options_proto"],
|
||||
deps = [":mediapipe_options_proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
|
@ -141,8 +142,8 @@ mediapipe_proto_library(
|
|||
srcs = ["test_calculators.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
":calculator_options_proto",
|
||||
":calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -150,7 +151,7 @@ mediapipe_proto_library(
|
|||
name = "thread_pool_executor_proto",
|
||||
srcs = ["thread_pool_executor.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = ["//mediapipe/framework:mediapipe_options_proto"],
|
||||
deps = [":mediapipe_options_proto"],
|
||||
)
|
||||
|
||||
# It is for pure-native Android builds where the library can't have any dependency on libandroid.so
|
||||
|
|
|
@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
package(default_visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
])
|
||||
|
||||
bzl_library(
|
||||
name = "expand_template_bzl",
|
||||
|
@ -50,13 +52,11 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "aligned_malloc_and_free",
|
||||
hdrs = ["aligned_malloc_and_free.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cleanup",
|
||||
hdrs = ["cleanup.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["@com_google_absl//absl/base:core_headers"],
|
||||
)
|
||||
|
||||
|
@ -86,7 +86,6 @@ cc_library(
|
|||
# Use this library through "mediapipe/framework/port:gtest_main".
|
||||
visibility = [
|
||||
"//mediapipe/framework/port:__pkg__",
|
||||
"//third_party/visionai/algorithms/tracking:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
|
@ -108,7 +107,6 @@ cc_library(
|
|||
name = "file_helpers",
|
||||
srcs = ["file_helpers.cc"],
|
||||
hdrs = ["file_helpers.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":file_path",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -134,7 +132,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "image_resizer",
|
||||
hdrs = ["image_resizer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
],
|
||||
|
@ -151,7 +148,9 @@ cc_library(
|
|||
cc_library(
|
||||
name = "mathutil",
|
||||
hdrs = ["mathutil.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
@ -171,7 +170,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "no_destructor",
|
||||
hdrs = ["no_destructor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -190,7 +188,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "random",
|
||||
hdrs = ["random_base.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/port:integral_types"],
|
||||
)
|
||||
|
||||
|
@ -211,14 +208,12 @@ cc_library(
|
|||
name = "registration_token",
|
||||
srcs = ["registration_token.cc"],
|
||||
hdrs = ["registration_token.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "registration",
|
||||
srcs = ["registration.cc"],
|
||||
hdrs = ["registration.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":registration_token",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
@ -279,7 +274,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"re2.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -310,7 +304,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "thread_options",
|
||||
hdrs = ["thread_options.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -356,7 +349,6 @@ cc_library(
|
|||
cc_test(
|
||||
name = "mathutil_unittest",
|
||||
srcs = ["mathutil_unittest.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":mathutil",
|
||||
"//mediapipe/framework/port:benchmark",
|
||||
|
@ -368,7 +360,6 @@ cc_test(
|
|||
name = "registration_token_test",
|
||||
srcs = ["registration_token_test.cc"],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":registration_token",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
@ -381,7 +372,6 @@ cc_test(
|
|||
timeout = "long",
|
||||
srcs = ["safe_int_test.cc"],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":intops",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
@ -393,7 +383,6 @@ cc_test(
|
|||
name = "monotonic_clock_test",
|
||||
srcs = ["monotonic_clock_test.cc"],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":clock",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
|
|
@ -361,7 +361,7 @@ void Tensor::AllocateOpenGlBuffer() const {
|
|||
LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread.";
|
||||
glGenBuffers(1, &opengl_buffer_);
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
|
||||
if (!AllocateAhwbMapToSsbo()) {
|
||||
if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) {
|
||||
glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY);
|
||||
}
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
|
||||
|
@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
|||
});
|
||||
} else
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
|
||||
{
|
||||
// Transfer data from texture if not transferred from SSBO/MTLBuffer
|
||||
// yet.
|
||||
if (valid_ & kValidOpenGlTexture2d) {
|
||||
|
@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
|||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
valid_ |= kValidCpu;
|
||||
}
|
||||
|
@ -609,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
|
|||
void Tensor::AllocateCpuBuffer() const {
|
||||
if (!cpu_buffer_) {
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
if (AllocateAHardwareBuffer()) return;
|
||||
if (use_ahwb_ && AllocateAHardwareBuffer()) return;
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
cpu_buffer_ = AllocateVirtualMemory(bytes());
|
||||
|
|
|
@ -39,10 +39,9 @@
|
|||
#endif // MEDIAPIPE_NO_JNI
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#include <EGL/egl.h>
|
||||
#include <EGL/eglext.h>
|
||||
#include <android/hardware_buffer.h>
|
||||
|
||||
#include "third_party/GL/gl/include/EGL/egl.h"
|
||||
#include "third_party/GL/gl/include/EGL/eglext.h"
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
|
@ -410,8 +409,8 @@ class Tensor {
|
|||
bool AllocateAHardwareBuffer(int size_alignment = 0) const;
|
||||
void CreateEglSyncAndFd() const;
|
||||
// Use Ahwb for other views: OpenGL / CPU buffer.
|
||||
static inline bool use_ahwb_ = false;
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
static inline bool use_ahwb_ = false;
|
||||
// Expects the target SSBO to be already bound.
|
||||
bool AllocateAhwbMapToSsbo() const;
|
||||
bool InsertAhwbToSsboFence() const;
|
||||
|
@ -419,6 +418,7 @@ class Tensor {
|
|||
void ReleaseAhwbStuff();
|
||||
void* MapAhwbToCpuRead() const;
|
||||
void* MapAhwbToCpuWrite() const;
|
||||
void MoveCpuOrSsboToAhwb() const;
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
mutable std::shared_ptr<mediapipe::GlContext> gl_context_;
|
||||
|
|
|
@ -4,12 +4,13 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#include <EGL/egl.h>
|
||||
#include <EGL/eglext.h>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "third_party/GL/gl/include/EGL/egl.h"
|
||||
#include "third_party/GL/gl/include/EGL/eglext.h"
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -213,11 +214,16 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
|
|||
"supported.";
|
||||
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
|
||||
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
|
||||
"supported on targe system.";
|
||||
"supported on target system.";
|
||||
bool transfer = !ahwb_;
|
||||
CHECK(AllocateAHardwareBuffer())
|
||||
<< "AHardwareBuffer is not supported on the target system.";
|
||||
valid_ |= kValidAHardwareBuffer;
|
||||
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
|
||||
if (transfer) {
|
||||
MoveCpuOrSsboToAhwb();
|
||||
} else {
|
||||
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
|
||||
}
|
||||
return {ahwb_,
|
||||
ssbo_written_,
|
||||
&fence_fd_, // The FD is created for SSBO -> AHWB synchronization.
|
||||
|
@ -262,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
|
|||
}
|
||||
|
||||
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
|
||||
if (!use_ahwb_) return false;
|
||||
if (__builtin_available(android 26, *)) {
|
||||
if (ahwb_ == nullptr) {
|
||||
AHardwareBuffer_Desc desc = {};
|
||||
|
@ -302,6 +307,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Moves Cpu/Ssbo resource under the Ahwb backed memory.
|
||||
void Tensor::MoveCpuOrSsboToAhwb() const {
|
||||
void* dest = nullptr;
|
||||
if (__builtin_available(android 26, *)) {
|
||||
auto error = AHardwareBuffer_lock(
|
||||
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
|
||||
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
|
||||
}
|
||||
if (valid_ & kValidOpenGlBuffer) {
|
||||
gl_context_->Run([this, dest]() {
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
|
||||
const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
|
||||
GL_MAP_READ_BIT);
|
||||
std::memcpy(dest, src, bytes());
|
||||
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
|
||||
glDeleteBuffers(1, &opengl_buffer_);
|
||||
});
|
||||
opengl_buffer_ = GL_INVALID_INDEX;
|
||||
gl_context_ = nullptr;
|
||||
} else if (valid_ & kValidCpu) {
|
||||
std::memcpy(dest, cpu_buffer_, bytes());
|
||||
// Free CPU memory because next time AHWB is mapped instead.
|
||||
free(cpu_buffer_);
|
||||
cpu_buffer_ = nullptr;
|
||||
} else {
|
||||
LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
|
||||
}
|
||||
if (__builtin_available(android 26, *)) {
|
||||
auto error = AHardwareBuffer_unlock(ahwb_, nullptr);
|
||||
CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
|
||||
}
|
||||
}
|
||||
|
||||
// SSBO is created on top of AHWB. A fence is inserted into the GPU queue before
|
||||
// the GPU task that is going to read from the SSBO. When the writing into AHWB
|
||||
// is finished then the GPU reads from the SSBO.
|
||||
|
|
171
mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
Normal file
171
mediapipe/framework/formats/tensor_ahwb_gpu_test.cc
Normal file
|
@ -0,0 +1,171 @@
|
|||
|
||||
#if !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
#include <android/hardware_buffer.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/formats/tensor_data_types.h"
|
||||
#include "mediapipe/gpu/gpu_test_base.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
|
||||
#include "testing/base/public/gunit.h"
|
||||
|
||||
// The test creates OpenGL ES buffer, fills the buffer with incrementing values
|
||||
// 0.0, 0.1, 0.2 etc. with the compute shader on GPU.
|
||||
// Then the test requests the CPU view and compares the values.
|
||||
// Float32 and Float16 tests are there.
|
||||
|
||||
namespace {
|
||||
|
||||
using mediapipe::Float16;
|
||||
using mediapipe::Tensor;
|
||||
|
||||
MATCHER_P(NearWithPrecision, precision, "") {
|
||||
return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision;
|
||||
}
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
|
||||
// Utility function to fill the GPU buffer.
|
||||
void FillGpuBuffer(GLuint name, std::size_t size,
|
||||
const Tensor::ElementType fmt) {
|
||||
std::string shader_source;
|
||||
if (fmt == Tensor::ElementType::kFloat32) {
|
||||
shader_source = R"( #version 310 es
|
||||
precision highp float;
|
||||
layout(local_size_x = 1, local_size_y = 1) in;
|
||||
layout(std430, binding = 0) buffer Output {float elements[];} output_data;
|
||||
void main() {
|
||||
uint v = gl_GlobalInvocationID.x * 2u;
|
||||
output_data.elements[v] = float(v) / 10.0;
|
||||
output_data.elements[v + 1u] = float(v + 1u) / 10.0;
|
||||
})";
|
||||
} else {
|
||||
shader_source = R"( #version 310 es
|
||||
precision highp float;
|
||||
layout(local_size_x = 1, local_size_y = 1) in;
|
||||
layout(std430, binding = 0) buffer Output {float elements[];} output_data;
|
||||
void main() {
|
||||
uint v = gl_GlobalInvocationID.x;
|
||||
uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0,
|
||||
(float(v) * 2.0 + 1.0) / 10.0));
|
||||
output_data.elements[v] = uintBitsToFloat(tmp);
|
||||
})";
|
||||
}
|
||||
GLuint shader;
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER));
|
||||
const GLchar* sources[] = {shader_source.c_str()};
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader));
|
||||
GLint is_compiled = 0;
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS,
|
||||
&is_compiled));
|
||||
if (is_compiled == GL_FALSE) {
|
||||
GLint max_length = 0;
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH,
|
||||
&max_length));
|
||||
std::vector<GLchar> error_log(max_length);
|
||||
glGetShaderInfoLog(shader, max_length, &max_length, error_log.data());
|
||||
glDeleteShader(shader);
|
||||
FAIL() << error_log.data();
|
||||
return;
|
||||
}
|
||||
GLuint to_buffer_program;
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program));
|
||||
|
||||
MP_ASSERT_OK(
|
||||
TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
|
||||
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
|
||||
}
|
||||
|
||||
class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
|
||||
public:
|
||||
};
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
EXPECT_GT(ssbo_name, 0);
|
||||
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
|
||||
tensor.element_type());
|
||||
});
|
||||
auto ptr = tensor.GetCpuReadView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
std::vector<float> reference;
|
||||
reference.resize(num_elements);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
reference[i] = static_cast<float>(i) / 10.0f;
|
||||
}
|
||||
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
|
||||
testing::Pointwise(testing::FloatEq(), reference));
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
EXPECT_GT(ssbo_name, 0);
|
||||
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
|
||||
tensor.element_type());
|
||||
});
|
||||
auto ptr = tensor.GetCpuReadView().buffer<Float16>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
std::vector<Float16> reference;
|
||||
reference.resize(num_elements);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
reference[i] = static_cast<float>(i) / 10.0f;
|
||||
}
|
||||
// Precision is set to a reasonable value for Float16.
|
||||
EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements),
|
||||
testing::Pointwise(NearWithPrecision(0.001), reference));
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
|
||||
// Request the CPU view to get the memory to be allocated.
|
||||
// Request Ahwb view then to transform the storage into Ahwb.
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
{
|
||||
auto ptr = tensor.GetCpuWriteView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
ptr[i] = static_cast<float>(i) / 10.0f;
|
||||
}
|
||||
}
|
||||
{
|
||||
auto view = tensor.GetAHardwareBufferReadView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
}
|
||||
auto ptr = tensor.GetCpuReadView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
std::vector<float> reference;
|
||||
reference.resize(num_elements);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
reference[i] = static_cast<float>(i) / 10.0f;
|
||||
}
|
||||
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
|
||||
testing::Pointwise(testing::FloatEq(), reference));
|
||||
}
|
||||
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
} // namespace
|
||||
|
||||
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
|
||||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
71
mediapipe/framework/formats/tensor_hardware_buffer.h
Normal file
71
mediapipe/framework/formats/tensor_hardware_buffer.h
Normal file
|
@ -0,0 +1,71 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
|
||||
|
||||
#if !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
|
||||
#include <android/hardware_buffer.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/framework/formats/tensor_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_internal.h"
|
||||
#include "mediapipe/framework/formats/tensor_v2.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Supports:
|
||||
// - float 16 and 32 bits
|
||||
// - signed / unsigned integers 8,16,32 bits
|
||||
class TensorHardwareBufferView;
|
||||
struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
|
||||
using ViewT = TensorHardwareBufferView;
|
||||
TensorBufferDescriptor buffer;
|
||||
};
|
||||
|
||||
class TensorHardwareBufferView : public Tensor::View {
|
||||
public:
|
||||
TENSOR_UNIQUE_VIEW_TYPE_ID();
|
||||
~TensorHardwareBufferView() = default;
|
||||
|
||||
const TensorHardwareBufferViewDescriptor& descriptor() const override {
|
||||
return descriptor_;
|
||||
}
|
||||
AHardwareBuffer* handle() const { return ahwb_handle_; }
|
||||
|
||||
protected:
|
||||
TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
|
||||
Tensor::View::State state,
|
||||
const TensorHardwareBufferViewDescriptor& desc,
|
||||
AHardwareBuffer* ahwb_handle)
|
||||
: Tensor::View(kId, access_capability, access, state),
|
||||
descriptor_(desc),
|
||||
ahwb_handle_(ahwb_handle) {}
|
||||
|
||||
private:
|
||||
bool MatchDescriptor(
|
||||
uint64_t view_type_id,
|
||||
const Tensor::ViewDescriptor& base_descriptor) const override {
|
||||
if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
|
||||
return false;
|
||||
auto descriptor =
|
||||
static_cast<const TensorHardwareBufferViewDescriptor&>(base_descriptor);
|
||||
return descriptor.buffer.format == descriptor_.buffer.format &&
|
||||
descriptor.buffer.size_alignment <=
|
||||
descriptor_.buffer.size_alignment &&
|
||||
descriptor_.buffer.size_alignment %
|
||||
descriptor.buffer.size_alignment ==
|
||||
0;
|
||||
}
|
||||
const TensorHardwareBufferViewDescriptor& descriptor_;
|
||||
AHardwareBuffer* ahwb_handle_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
|
|
@ -0,0 +1,216 @@
|
|||
#if !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/formats/tensor_backend.h"
|
||||
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_v2.h"
|
||||
#include "util/task/status_macros.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class TensorCpuViewImpl : public TensorCpuView {
|
||||
public:
|
||||
TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
|
||||
Tensor::View::State state,
|
||||
const TensorCpuViewDescriptor& descriptor, void* pointer,
|
||||
AHardwareBuffer* ahwb_handle)
|
||||
: TensorCpuView(access_capabilities, access, state, descriptor, pointer),
|
||||
ahwb_handle_(ahwb_handle) {}
|
||||
~TensorCpuViewImpl() {
|
||||
// If handle_ is null then this view is constructed in GetViews with no
|
||||
// access.
|
||||
if (ahwb_handle_) {
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_unlock(ahwb_handle_, nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
AHardwareBuffer* ahwb_handle_;
|
||||
};
|
||||
|
||||
class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
|
||||
public:
|
||||
TensorHardwareBufferViewImpl(
|
||||
int access_capability, Tensor::View::Access access,
|
||||
Tensor::View::State state,
|
||||
const TensorHardwareBufferViewDescriptor& descriptor,
|
||||
AHardwareBuffer* handle)
|
||||
: TensorHardwareBufferView(access_capability, access, state, descriptor,
|
||||
handle) {}
|
||||
~TensorHardwareBufferViewImpl() = default;
|
||||
};
|
||||
|
||||
class HardwareBufferCpuStorage : public TensorStorage {
|
||||
public:
|
||||
~HardwareBufferCpuStorage() {
|
||||
if (!ahwb_handle_) return;
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_release(ahwb_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
static absl::Status CanProvide(
|
||||
int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
|
||||
const Tensor::ViewDescriptor& base_descriptor) {
|
||||
// TODO: use AHardwareBuffer_isSupported for API >= 29.
|
||||
static const bool is_ahwb_supported = [] {
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_Desc desc = {};
|
||||
// Aligned to the largest possible virtual memory page size.
|
||||
constexpr uint32_t kPageSize = 16384;
|
||||
desc.width = kPageSize;
|
||||
desc.height = 1;
|
||||
desc.layers = 1;
|
||||
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
|
||||
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
|
||||
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
|
||||
AHardwareBuffer* handle;
|
||||
if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
|
||||
AHardwareBuffer_release(handle);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
if (!is_ahwb_supported) {
|
||||
return absl::UnavailableError(
|
||||
"AHardwareBuffer is not supported on the platform.");
|
||||
}
|
||||
|
||||
if (view_type_id != TensorCpuView::kId &&
|
||||
view_type_id != TensorHardwareBufferView::kId) {
|
||||
return absl::InvalidArgumentError(
|
||||
"A view type is not supported by this storage.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
|
||||
std::vector<std::unique_ptr<Tensor::View>> result;
|
||||
auto update_state = latest_version == version_
|
||||
? Tensor::View::State::kUpToDate
|
||||
: Tensor::View::State::kOutdated;
|
||||
if (ahwb_handle_) {
|
||||
result.push_back(
|
||||
std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
|
||||
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
|
||||
hw_descriptor_, ahwb_handle_)));
|
||||
|
||||
result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
|
||||
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
|
||||
cpu_descriptor_, nullptr, nullptr)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
|
||||
Tensor::View::Access access, const Tensor::Shape& shape,
|
||||
uint64_t latest_version, uint64_t view_type_id,
|
||||
const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
CanProvide(access_capability, shape, view_type_id, base_descriptor));
|
||||
const auto& buffer_descriptor =
|
||||
view_type_id == TensorHardwareBufferView::kId
|
||||
? static_cast<const TensorHardwareBufferViewDescriptor&>(
|
||||
base_descriptor)
|
||||
.buffer
|
||||
: static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
|
||||
.buffer;
|
||||
if (!ahwb_handle_) {
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_Desc desc = {};
|
||||
desc.width = TensorBufferSize(buffer_descriptor, shape);
|
||||
desc.height = 1;
|
||||
desc.layers = 1;
|
||||
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
|
||||
// TODO: Use access capabilities to set hints.
|
||||
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
|
||||
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
|
||||
auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
|
||||
if (error != 0) {
|
||||
return absl::UnknownError(
|
||||
absl::StrCat("Error allocating hardware buffer: ", error));
|
||||
}
|
||||
// Fill all possible views to provide it as proto views.
|
||||
hw_descriptor_.buffer = buffer_descriptor;
|
||||
cpu_descriptor_.buffer = buffer_descriptor;
|
||||
}
|
||||
}
|
||||
if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
|
||||
buffer_descriptor.size_alignment >
|
||||
hw_descriptor_.buffer.size_alignment ||
|
||||
hw_descriptor_.buffer.size_alignment %
|
||||
buffer_descriptor.size_alignment >
|
||||
0) {
|
||||
return absl::AlreadyExistsError(
|
||||
"A view with different params is already allocated with this "
|
||||
"storage");
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Tensor::View>> result;
|
||||
if (view_type_id == TensorHardwareBufferView::kId) {
|
||||
result = GetAhwbView(access, shape, base_descriptor);
|
||||
} else {
|
||||
result = GetCpuView(access, shape, base_descriptor);
|
||||
}
|
||||
if (result.ok()) version_ = latest_version;
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
|
||||
Tensor::View::Access access, const Tensor::Shape& shape,
|
||||
const Tensor::ViewDescriptor& base_descriptor) {
|
||||
return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
|
||||
kAccessCapability, access, Tensor::View::State::kUpToDate,
|
||||
hw_descriptor_, ahwb_handle_));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
|
||||
Tensor::View::Access access, const Tensor::Shape& shape,
|
||||
const Tensor::ViewDescriptor& base_descriptor) {
|
||||
void* pointer = nullptr;
|
||||
if (__builtin_available(android 26, *)) {
|
||||
int error =
|
||||
AHardwareBuffer_lock(ahwb_handle_,
|
||||
access == Tensor::View::Access::kWriteOnly
|
||||
? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
|
||||
: AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
|
||||
-1, nullptr, &pointer);
|
||||
if (error != 0) {
|
||||
return absl::UnknownError(
|
||||
absl::StrCat("Error locking hardware buffer: ", error));
|
||||
}
|
||||
}
|
||||
return std::unique_ptr<Tensor::View>(
|
||||
new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
|
||||
? Tensor::View::AccessCapability::kWrite
|
||||
: Tensor::View::AccessCapability::kRead,
|
||||
access, Tensor::View::State::kUpToDate,
|
||||
cpu_descriptor_, pointer, ahwb_handle_));
|
||||
}
|
||||
|
||||
static constexpr int kAccessCapability =
|
||||
Tensor::View::AccessCapability::kRead |
|
||||
Tensor::View::AccessCapability::kWrite;
|
||||
TensorHardwareBufferViewDescriptor hw_descriptor_;
|
||||
AHardwareBuffer* ahwb_handle_ = nullptr;
|
||||
|
||||
TensorCpuViewDescriptor cpu_descriptor_;
|
||||
uint64_t version_ = 0;
|
||||
};
|
||||
TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
|
||||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
|
@ -0,0 +1,76 @@
|
|||
|
||||
#if !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
#include <android/hardware_buffer.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_v2.h"
|
||||
#include "testing/base/public/gmock.h"
|
||||
#include "testing/base/public/gunit.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
class TensorHardwareBufferTest : public ::testing::Test {
|
||||
public:
|
||||
TensorHardwareBufferTest() {}
|
||||
~TensorHardwareBufferTest() override {}
|
||||
};
|
||||
|
||||
TEST_F(TensorHardwareBufferTest, TestFloat32) {
|
||||
Tensor tensor{Tensor::Shape({1})};
|
||||
{
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto view,
|
||||
tensor.GetView<Tensor::View::Access::kWriteOnly>(
|
||||
TensorHardwareBufferViewDescriptor{
|
||||
.buffer = {.format =
|
||||
TensorBufferDescriptor::Format::kFloat32}}));
|
||||
EXPECT_NE(view->handle(), nullptr);
|
||||
}
|
||||
{
|
||||
const auto& const_tensor = tensor;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto view,
|
||||
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
|
||||
TensorCpuViewDescriptor{
|
||||
.buffer = {.format =
|
||||
TensorBufferDescriptor::Format::kFloat32}}));
|
||||
EXPECT_NE(view->data<void>(), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorHardwareBufferTest, TestInt8Padding) {
|
||||
Tensor tensor{Tensor::Shape({1})};
|
||||
|
||||
{
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto view,
|
||||
tensor.GetView<Tensor::View::Access::kWriteOnly>(
|
||||
TensorHardwareBufferViewDescriptor{
|
||||
.buffer = {.format = TensorBufferDescriptor::Format::kInt8,
|
||||
.size_alignment = 4}}));
|
||||
EXPECT_NE(view->handle(), nullptr);
|
||||
}
|
||||
{
|
||||
const auto& const_tensor = tensor;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto view,
|
||||
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
|
||||
TensorCpuViewDescriptor{
|
||||
.buffer = {.format = TensorBufferDescriptor::Format::kInt8}}));
|
||||
EXPECT_NE(view->data<void>(), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
|
||||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
|
@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream {
|
|||
friend class GraphProfiler;
|
||||
// Accesses OutputStreamShard for profiling.
|
||||
friend class GraphTracer;
|
||||
// Accesses OutputStreamShard for profiling.
|
||||
friend class PerfettoTraceScope;
|
||||
// Accesses OutputStreamShard for post processing.
|
||||
friend class OutputStreamManager;
|
||||
};
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
default_visibility = ["//visibility:public"],
|
||||
features = ["-parse_headers"],
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,6 @@ config_setting(
|
|||
define_values = {
|
||||
"USE_MEDIAPIPE_THREADPOOL": "1",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
#TODO : remove from OSS.
|
||||
|
@ -37,13 +36,11 @@ config_setting(
|
|||
define_values = {
|
||||
"USE_MEDIAPIPE_THREADPOOL": "0",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "aligned_malloc_and_free",
|
||||
hdrs = ["aligned_malloc_and_free.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/deps:aligned_malloc_and_free",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
|
@ -57,7 +54,6 @@ cc_library(
|
|||
"advanced_proto_inc.h",
|
||||
"proto_ns.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":advanced_proto_lite",
|
||||
":core_proto",
|
||||
|
@ -72,7 +68,6 @@ cc_library(
|
|||
"advanced_proto_lite_inc.h",
|
||||
"proto_ns.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":core_proto",
|
||||
"//mediapipe/framework:port",
|
||||
|
@ -83,7 +78,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "any_proto",
|
||||
hdrs = ["any_proto.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":core_proto",
|
||||
],
|
||||
|
@ -94,7 +88,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"commandlineflags.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//third_party:glog",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
|
@ -107,7 +100,6 @@ cc_library(
|
|||
"core_proto_inc.h",
|
||||
"proto_ns.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
|
@ -117,7 +109,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "file_helpers",
|
||||
hdrs = ["file_helpers.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":status",
|
||||
"//mediapipe/framework/deps:file_helpers",
|
||||
|
@ -128,7 +119,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "image_resizer",
|
||||
hdrs = ["image_resizer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//conditions:default": [
|
||||
"//mediapipe/framework/deps:image_resizer",
|
||||
|
@ -140,14 +130,12 @@ cc_library(
|
|||
cc_library(
|
||||
name = "integral_types",
|
||||
hdrs = ["integral_types.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark",
|
||||
testonly = 1,
|
||||
hdrs = ["benchmark.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
|
@ -158,7 +146,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"re2.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/deps:re2",
|
||||
],
|
||||
|
@ -173,7 +160,6 @@ cc_library(
|
|||
"gtest-spi.h",
|
||||
"status_matchers.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":status_matchers",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
|
@ -190,7 +176,6 @@ cc_library(
|
|||
"gtest-spi.h",
|
||||
"status_matchers.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":status_matchers",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
|
@ -204,7 +189,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"logging.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"//third_party:glog",
|
||||
|
@ -217,7 +201,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"map_util.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework/deps:map_util",
|
||||
|
@ -227,7 +210,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "numbers",
|
||||
hdrs = ["numbers.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/deps:numbers"],
|
||||
)
|
||||
|
||||
|
@ -238,13 +220,11 @@ config_setting(
|
|||
define_values = {
|
||||
"MEDIAPIPE_DISABLE_OPENCV": "1",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "opencv_core",
|
||||
hdrs = ["opencv_core_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//third_party:opencv",
|
||||
],
|
||||
|
@ -253,7 +233,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_imgproc",
|
||||
hdrs = ["opencv_imgproc_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//third_party:opencv",
|
||||
|
@ -263,7 +242,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_imgcodecs",
|
||||
hdrs = ["opencv_imgcodecs_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//third_party:opencv",
|
||||
|
@ -273,7 +251,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_highgui",
|
||||
hdrs = ["opencv_highgui_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//third_party:opencv",
|
||||
|
@ -283,7 +260,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_video",
|
||||
hdrs = ["opencv_video_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//mediapipe/framework:port",
|
||||
|
@ -294,7 +270,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_features2d",
|
||||
hdrs = ["opencv_features2d_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//third_party:opencv",
|
||||
|
@ -304,7 +279,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_calib3d",
|
||||
hdrs = ["opencv_calib3d_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//third_party:opencv",
|
||||
|
@ -314,7 +288,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "opencv_videoio",
|
||||
hdrs = ["opencv_videoio_inc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":opencv_core",
|
||||
"//mediapipe/framework:port",
|
||||
|
@ -328,7 +301,6 @@ cc_library(
|
|||
"parse_text_proto.h",
|
||||
"proto_ns.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":core_proto",
|
||||
":logging",
|
||||
|
@ -339,14 +311,12 @@ cc_library(
|
|||
cc_library(
|
||||
name = "point",
|
||||
hdrs = ["point2.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/deps:point"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "port",
|
||||
hdrs = ["port.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
|
@ -356,14 +326,12 @@ cc_library(
|
|||
cc_library(
|
||||
name = "rectangle",
|
||||
hdrs = ["rectangle.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/deps:rectangle"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ret_check",
|
||||
hdrs = ["ret_check.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework/deps:ret_check",
|
||||
|
@ -373,7 +341,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "singleton",
|
||||
hdrs = ["singleton.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/deps:singleton"],
|
||||
)
|
||||
|
||||
|
@ -382,7 +349,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"source_location.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework/deps:source_location",
|
||||
|
@ -397,7 +363,6 @@ cc_library(
|
|||
"status_builder.h",
|
||||
"status_macros.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":source_location",
|
||||
"//mediapipe/framework:port",
|
||||
|
@ -412,7 +377,6 @@ cc_library(
|
|||
hdrs = [
|
||||
"statusor.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
@ -423,7 +387,6 @@ cc_library(
|
|||
name = "status_matchers",
|
||||
testonly = 1,
|
||||
hdrs = ["status_matchers.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":status",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
@ -433,7 +396,6 @@ cc_library(
|
|||
cc_library(
|
||||
name = "threadpool",
|
||||
hdrs = ["threadpool.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//conditions:default": [":threadpool_impl_default_to_google"],
|
||||
"//mediapipe:android": [":threadpool_impl_default_to_mediapipe"],
|
||||
|
@ -460,7 +422,6 @@ alias(
|
|||
cc_library(
|
||||
name = "topologicalsorter",
|
||||
hdrs = ["topologicalsorter.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework/deps:topologicalsorter",
|
||||
|
@ -470,6 +431,5 @@ cc_library(
|
|||
cc_library(
|
||||
name = "vector",
|
||||
hdrs = ["vector.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/deps:vector"],
|
||||
)
|
||||
|
|
|
@ -228,6 +228,8 @@ def mediapipe_ts_library(
|
|||
srcs = srcs,
|
||||
visibility = visibility,
|
||||
deps = deps + [
|
||||
"@npm//@types/jasmine",
|
||||
"@npm//@types/node",
|
||||
"@npm//@types/offscreencanvas",
|
||||
"@npm//@types/google-protobuf",
|
||||
],
|
||||
|
|
|
@ -140,7 +140,7 @@ cc_library(
|
|||
name = "circular_buffer",
|
||||
hdrs = ["circular_buffer.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
"//mediapipe:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
@ -151,7 +151,6 @@ cc_test(
|
|||
name = "circular_buffer_test",
|
||||
size = "small",
|
||||
srcs = ["circular_buffer_test.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":circular_buffer",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
@ -164,7 +163,7 @@ cc_library(
|
|||
name = "trace_buffer",
|
||||
srcs = ["trace_buffer.h"],
|
||||
hdrs = ["trace_buffer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//mediapipe/framework/profiler:__subpackages__"],
|
||||
deps = [
|
||||
":circular_buffer",
|
||||
"//mediapipe/framework:calculator_profile_cc_proto",
|
||||
|
@ -292,9 +291,7 @@ cc_library(
|
|||
"-ObjC++",
|
||||
],
|
||||
}),
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
|
|
@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
|
|||
|
||||
const ProfilerConfig& profiler_config() { return profiler_config_; }
|
||||
|
||||
// Helper method to expose the config to other profilers.
|
||||
const ValidatedGraphConfig* GetValidatedGraphConfig() {
|
||||
return validated_graph_;
|
||||
}
|
||||
|
||||
private:
|
||||
// This can be used to add packet info for the input streams to the graph.
|
||||
// It treats the stream defined by |stream_name| as a stream produced by a
|
||||
|
|
|
@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() {
|
|||
// Note: state_mutex_ is held when this function is entered or
|
||||
// exited.
|
||||
void Scheduler::HandleIdle() {
|
||||
if (handling_idle_) {
|
||||
if (++handling_idle_ > 1) {
|
||||
// Someone is already inside this method.
|
||||
// Note: This can happen in the sections below where we unlock the mutex
|
||||
// and make more nodes runnable: the nodes can run and become idle again
|
||||
|
@ -127,7 +127,6 @@ void Scheduler::HandleIdle() {
|
|||
VLOG(2) << "HandleIdle: already in progress";
|
||||
return;
|
||||
}
|
||||
handling_idle_ = true;
|
||||
|
||||
while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) {
|
||||
// Remove active sources that are closed.
|
||||
|
@ -165,11 +164,17 @@ void Scheduler::HandleIdle() {
|
|||
}
|
||||
}
|
||||
|
||||
// If HandleIdle has been called again, then continue scheduling.
|
||||
if (handling_idle_ > 1) {
|
||||
handling_idle_ = 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Nothing left to do.
|
||||
break;
|
||||
}
|
||||
|
||||
handling_idle_ = false;
|
||||
handling_idle_ = 0;
|
||||
}
|
||||
|
||||
// Note: state_mutex_ is held when this function is entered or exited.
|
||||
|
|
|
@ -302,7 +302,7 @@ class Scheduler {
|
|||
// - We need it to be reentrant, which Mutex does not support.
|
||||
// - We want simultaneous calls to return immediately instead of waiting,
|
||||
// and Mutex's TryLock is not guaranteed to work.
|
||||
bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false;
|
||||
int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0;
|
||||
|
||||
// Mutex for the scheduler state and related things.
|
||||
// Note: state_ is declared as atomic so that its getter methods don't need
|
||||
|
|
|
@ -90,7 +90,7 @@ mediapipe_proto_library(
|
|||
name = "packet_generator_wrapper_calculator_proto",
|
||||
srcs = ["packet_generator_wrapper_calculator.proto"],
|
||||
def_py_proto = False,
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:packet_generator_proto",
|
||||
|
@ -120,13 +120,13 @@ cc_library(
|
|||
name = "fill_packet_set",
|
||||
srcs = ["fill_packet_set.cc"],
|
||||
hdrs = ["fill_packet_set.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
":status_util",
|
||||
"//mediapipe/framework:packet_set",
|
||||
"//mediapipe/framework:packet_type",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
@ -162,7 +162,6 @@ cc_library(
|
|||
cc_test(
|
||||
name = "executor_util_test",
|
||||
srcs = ["executor_util_test.cc"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
":executor_util",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
@ -173,7 +172,7 @@ cc_test(
|
|||
cc_library(
|
||||
name = "options_map",
|
||||
hdrs = ["options_map.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [
|
||||
":type_util",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
|
@ -193,7 +192,7 @@ cc_library(
|
|||
name = "options_field_util",
|
||||
srcs = ["options_field_util.cc"],
|
||||
hdrs = ["options_field_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":field_data_cc_proto",
|
||||
":name_util",
|
||||
|
@ -216,7 +215,7 @@ cc_library(
|
|||
name = "options_syntax_util",
|
||||
srcs = ["options_syntax_util.cc"],
|
||||
hdrs = ["options_syntax_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":name_util",
|
||||
":options_field_util",
|
||||
|
@ -235,8 +234,9 @@ cc_library(
|
|||
name = "options_util",
|
||||
srcs = ["options_util.cc"],
|
||||
hdrs = ["options_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":name_util",
|
||||
":options_field_util",
|
||||
":options_map",
|
||||
":options_registry",
|
||||
|
@ -254,7 +254,6 @@ cc_library(
|
|||
"//mediapipe/framework/port:advanced_proto",
|
||||
"//mediapipe/framework/port:any_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:name_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -323,7 +322,7 @@ mediapipe_cc_test(
|
|||
cc_library(
|
||||
name = "packet_generator_wrapper_calculator",
|
||||
srcs = ["packet_generator_wrapper_calculator.cc"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
":packet_generator_wrapper_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
|
@ -347,6 +346,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -507,6 +507,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":proto_util_lite",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/deps:proto_descriptor_cc_proto",
|
||||
"//mediapipe/framework/port:advanced_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
|
|
@ -27,6 +27,9 @@ message TemplateExpression {
|
|||
// The FieldDescriptor::Type of the modified field.
|
||||
optional mediapipe.FieldDescriptorProto.Type field_type = 5;
|
||||
|
||||
// The FieldDescriptor::Type of each map key in the path.
|
||||
repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
|
||||
|
||||
// Alternative value for the modified field, in protobuf binary format.
|
||||
optional string field_value = 7;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/framework/tool/field_data.pb.h"
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
|
||||
|
@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
|
|||
|
||||
// Extracts the data value(s) for one field from a serialized message.
|
||||
// The message with these field values removed is written to |out|.
|
||||
absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
|
||||
CodedInputStream* in, CodedOutputStream* out,
|
||||
absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
|
||||
CodedOutputStream* out,
|
||||
std::vector<std::string>* field_values) {
|
||||
uint32 tag;
|
||||
while ((tag = in->ReadTag()) != 0) {
|
||||
int field_number = WireFormatLite::GetTagFieldNumber(tag);
|
||||
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
|
||||
if (field_number == field_id) {
|
||||
if (!IsLengthDelimited(wire_type) &&
|
||||
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
|
||||
|
@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
|
|||
CodedInputStream in(&ais);
|
||||
StringOutputStream sos(&message_);
|
||||
CodedOutputStream out(&sos);
|
||||
WireFormatLite::WireType wire_type =
|
||||
WireFormatLite::WireTypeForFieldType(field_type_);
|
||||
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
|
||||
return GetFieldValues(field_id_, &in, &out, &field_values_);
|
||||
}
|
||||
|
||||
void FieldAccess::GetMessage(std::string* result) {
|
||||
|
@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
|
|||
return &field_values_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
|
||||
|
||||
// Returns the FieldAccess and index for a field-id or a map-id.
|
||||
// Returns access to the field-id if the field index is found,
|
||||
// to the map-id if the map entry is found, and to the field-id otherwise.
|
||||
absl::StatusOr<std::pair<FieldAccess, int>> AccessField(
|
||||
const ProtoPathEntry& entry, FieldType field_type,
|
||||
const FieldValue& message) {
|
||||
FieldAccess result(entry.field_id, field_type);
|
||||
if (entry.field_id >= 0) {
|
||||
MP_RETURN_IF_ERROR(result.SetMessage(message));
|
||||
if (entry.index < result.mutable_field_values()->size()) {
|
||||
return std::pair(result, entry.index);
|
||||
}
|
||||
}
|
||||
if (entry.map_id >= 0) {
|
||||
FieldAccess access(entry.map_id, field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||
auto& field_values = *access.mutable_field_values();
|
||||
for (int index = 0; index < field_values.size(); ++index) {
|
||||
FieldAccess key(entry.key_id, entry.key_type);
|
||||
MP_RETURN_IF_ERROR(key.SetMessage(field_values[index]));
|
||||
if (key.mutable_field_values()->at(0) == entry.key_value) {
|
||||
return std::pair(std::move(access), index);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (entry.field_id >= 0) {
|
||||
return std::pair(result, entry.index);
|
||||
}
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ",
|
||||
entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Replaces a range of field values for one field nested within a protobuf.
|
||||
absl::Status ProtoUtilLite::ReplaceFieldRange(
|
||||
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
|
||||
const std::vector<FieldValue>& field_values) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.front();
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldAccess access(field_id, !proto_path.empty()
|
||||
? WireFormatLite::TYPE_MESSAGE
|
||||
: field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(*message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
|
||||
|
@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
|
|||
absl::Status ProtoUtilLite::GetFieldRange(
|
||||
const FieldValue& message, ProtoPath proto_path, int length,
|
||||
FieldType field_type, std::vector<FieldValue>* field_values) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.front();
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldAccess access(field_id, !proto_path.empty()
|
||||
? WireFormatLite::TYPE_MESSAGE
|
||||
: field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldRange(v[index], proto_path, length, field_type, field_values));
|
||||
} else {
|
||||
if (length == -1) {
|
||||
length = v.size() - index;
|
||||
}
|
||||
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||
field_values->insert(field_values->begin(), v.begin() + index,
|
||||
|
@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
|
|||
ProtoPath proto_path,
|
||||
FieldType field_type,
|
||||
int* field_count) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = proto_path.back();
|
||||
proto_path.pop_back();
|
||||
std::vector<std::string> parent;
|
||||
if (proto_path.empty()) {
|
||||
parent.push_back(std::string(message));
|
||||
ProtoPathEntry entry = proto_path.front();
|
||||
proto_path.erase(proto_path.begin());
|
||||
FieldType type =
|
||||
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
|
||||
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
|
||||
FieldAccess& access = r.first;
|
||||
int index = r.second;
|
||||
std::vector<FieldValue>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldCount(v[index], proto_path, field_type, field_count));
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
|
||||
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
|
||||
*field_count = v.size();
|
||||
}
|
||||
FieldAccess access(field_id, field_type);
|
||||
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
|
||||
*field_count = access.mutable_field_values()->size();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -34,15 +34,36 @@ class ProtoUtilLite {
|
|||
// Defines field types and tag formats.
|
||||
using WireFormatLite = proto_ns::internal::WireFormatLite;
|
||||
|
||||
// Defines a sequence of nested field-number field-index pairs.
|
||||
using ProtoPath = std::vector<std::pair<int, int>>;
|
||||
|
||||
// The serialized value for a protobuf field.
|
||||
using FieldValue = std::string;
|
||||
|
||||
// The serialized data type for a protobuf field.
|
||||
using FieldType = WireFormatLite::FieldType;
|
||||
|
||||
// A field-id and index, or a map-id and key, or both.
|
||||
struct ProtoPathEntry {
|
||||
ProtoPathEntry(int id, int index) : field_id(id), index(index) {}
|
||||
ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value)
|
||||
: map_id(id),
|
||||
key_id(key_id),
|
||||
key_type(key_type),
|
||||
key_value(std::move(key_value)) {}
|
||||
bool operator==(const ProtoPathEntry& o) const {
|
||||
return field_id == o.field_id && index == o.index && map_id == o.map_id &&
|
||||
key_id == o.key_id && key_type == o.key_type &&
|
||||
key_value == o.key_value;
|
||||
}
|
||||
int field_id = -1;
|
||||
int index = -1;
|
||||
int map_id = -1;
|
||||
int key_id = -1;
|
||||
FieldType key_type = FieldType::MAX_FIELD_TYPE;
|
||||
FieldValue key_value;
|
||||
};
|
||||
|
||||
// Defines a sequence of nested field-number field-index pairs.
|
||||
using ProtoPath = std::vector<ProtoPathEntry>;
|
||||
|
||||
class FieldAccess {
|
||||
public:
|
||||
// Provides access to a certain protobuf field.
|
||||
|
@ -57,9 +78,11 @@ class ProtoUtilLite {
|
|||
// Returns the serialized values of the protobuf field.
|
||||
std::vector<FieldValue>* mutable_field_values();
|
||||
|
||||
uint32 field_id() const { return field_id_; }
|
||||
|
||||
private:
|
||||
const uint32 field_id_;
|
||||
const FieldType field_type_;
|
||||
uint32 field_id_;
|
||||
FieldType field_type_;
|
||||
std::string message_;
|
||||
std::vector<FieldValue> field_values_;
|
||||
};
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
|
@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite;
|
|||
using FieldValue = ProtoUtilLite::FieldValue;
|
||||
using FieldType = ProtoUtilLite::FieldType;
|
||||
using ProtoPath = ProtoUtilLite::ProtoPath;
|
||||
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns the (tag, index) pairs in a field path.
|
||||
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]".
|
||||
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
|
||||
absl::Status status;
|
||||
std::vector<std::string> ids = absl::StrSplit(path, '/');
|
||||
for (const std::string& id : ids) {
|
||||
if (id.length() > 0) {
|
||||
std::pair<std::string, std::string> id_pair =
|
||||
absl::StrSplit(id, absl::ByAnyChar("[]"));
|
||||
int tag = 0;
|
||||
int index = 0;
|
||||
bool ok = absl::SimpleAtoi(id_pair.first, &tag) &&
|
||||
absl::SimpleAtoi(id_pair.second, &index);
|
||||
if (!ok) {
|
||||
status.Update(absl::InvalidArgumentError(path));
|
||||
}
|
||||
result->push_back(std::make_pair(tag, index));
|
||||
// Parses one ProtoPathEntry.
|
||||
// The parsed entry is appended to `result` and removed from `path`.
|
||||
// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes
|
||||
// to serialize the key text to protobuf wire format.
|
||||
absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
|
||||
bool ok = true;
|
||||
int sb = path.find('[');
|
||||
int eb = path.find(']');
|
||||
int field_id = -1;
|
||||
ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
|
||||
auto selector = path.substr(sb + 1, eb - 1 - sb);
|
||||
if (absl::StartsWith(selector, "@")) {
|
||||
int eq = selector.find('=');
|
||||
int key_id = -1;
|
||||
ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
|
||||
auto key_text = selector.substr(eq + 1);
|
||||
FieldType key_type = FieldType::TYPE_STRING;
|
||||
result->push_back({field_id, key_id, key_type, std::string(key_text)});
|
||||
} else {
|
||||
int index = 0;
|
||||
ok &= absl::SimpleAtoi(selector, &index);
|
||||
result->push_back({field_id, index});
|
||||
}
|
||||
int end = path.find('/', eb);
|
||||
if (end == std::string::npos) {
|
||||
path = "";
|
||||
} else {
|
||||
path = path.substr(end + 1);
|
||||
}
|
||||
return ok ? absl::OkStatus()
|
||||
: absl::InvalidArgumentError(
|
||||
absl::StrCat("Failed to parse ProtoPath entry: ", path));
|
||||
}
|
||||
|
||||
// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
|
||||
// Each ProtoPathEntry::key_value is converted from text to the protobuf
|
||||
// wire format for its key type.
|
||||
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
|
||||
ProtoPath* result) {
|
||||
int i = 0;
|
||||
for (ProtoPathEntry& entry : *result) {
|
||||
if (entry.map_id >= 0) {
|
||||
FieldType key_type = key_types[i++];
|
||||
std::vector<FieldValue> key_value;
|
||||
MP_RETURN_IF_ERROR(
|
||||
ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
|
||||
entry.key_type = key_type;
|
||||
entry.key_value = key_value.front();
|
||||
}
|
||||
}
|
||||
return status;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the (tag, index) pairs in a field path.
|
||||
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
|
||||
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
|
||||
result->clear();
|
||||
absl::string_view rest = path;
|
||||
if (absl::StartsWith(rest, "/")) {
|
||||
rest = rest.substr(1);
|
||||
}
|
||||
while (!rest.empty()) {
|
||||
MP_RETURN_IF_ERROR(ParseEntry(rest, result));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Parse the TemplateExpression.path field into a ProtoPath struct.
|
||||
absl::Status ParseProtoPath(const TemplateExpression& rule,
|
||||
std::string base_path, ProtoPath* result) {
|
||||
ProtoPath base_entries;
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
|
||||
std::vector<FieldType> key_types;
|
||||
for (int type : rule.key_type()) {
|
||||
key_types.push_back(static_cast<FieldType>(type));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
|
||||
result->erase(result->begin(), result->begin() + base_entries.size());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns true if one proto path is prefix by another.
|
||||
|
@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
|
|||
return absl::StartsWith(path, prefix);
|
||||
}
|
||||
|
||||
// Returns the part of one proto path after a prefix proto path.
|
||||
std::string ProtoPathRelative(const std::string& field_path,
|
||||
const std::string& base_path) {
|
||||
CHECK(ProtoPathStartsWith(field_path, base_path));
|
||||
return field_path.substr(base_path.length());
|
||||
}
|
||||
|
||||
// Returns the target ProtoUtilLite::FieldType of a rule.
|
||||
FieldType GetFieldType(const TemplateExpression& rule) {
|
||||
return static_cast<FieldType>(rule.field_type());
|
||||
|
@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) {
|
|||
// Returns the count of field values at a ProtoPath.
|
||||
int FieldCount(const FieldValue& base, ProtoPath field_path,
|
||||
FieldType field_type) {
|
||||
int field_id, index;
|
||||
std::tie(field_id, index) = field_path.back();
|
||||
field_path.pop_back();
|
||||
std::vector<FieldValue> parent;
|
||||
if (field_path.empty()) {
|
||||
parent.push_back(base);
|
||||
} else {
|
||||
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
|
||||
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
|
||||
}
|
||||
ProtoUtilLite::FieldAccess access(field_id, field_type);
|
||||
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
|
||||
return access.mutable_field_values()->size();
|
||||
int result = 0;
|
||||
CHECK(
|
||||
ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -229,9 +276,7 @@ class TemplateExpanderImpl {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
ProtoPath field_path;
|
||||
absl::Status status =
|
||||
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path);
|
||||
if (!status.ok()) return status;
|
||||
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
|
||||
return ProtoUtilLite::GetFieldRange(output, field_path, 1,
|
||||
GetFieldType(rule), base);
|
||||
}
|
||||
|
@ -242,12 +287,13 @@ class TemplateExpanderImpl {
|
|||
const std::vector<FieldValue>& field_values,
|
||||
FieldValue* output) {
|
||||
if (!rule.has_path()) {
|
||||
*output = field_values[0];
|
||||
if (!field_values.empty()) {
|
||||
*output = field_values[0];
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
ProtoPath field_path;
|
||||
RET_CHECK_OK(
|
||||
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
|
||||
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
|
||||
int field_count = 1;
|
||||
if (rule.has_field_value()) {
|
||||
// For a non-repeated field, only one value can be specified.
|
||||
|
@ -257,7 +303,7 @@ class TemplateExpanderImpl {
|
|||
"Multiple values specified for non-repeated field: ", rule.path()));
|
||||
}
|
||||
// For a non-repeated field, the field value is stored only in the rule.
|
||||
field_path[field_path.size() - 1].second = 0;
|
||||
field_path[field_path.size() - 1].index = 0;
|
||||
field_count = 0;
|
||||
}
|
||||
return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/deps/proto_descriptor.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
|
@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message;
|
|||
using mediapipe::proto_ns::OneofDescriptor;
|
||||
using mediapipe::proto_ns::Reflection;
|
||||
using mediapipe::proto_ns::TextFormat;
|
||||
using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
|
||||
using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
|
||||
using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -1357,32 +1361,138 @@ absl::Status ProtoPathSplit(const std::string& path,
|
|||
if (!ok) {
|
||||
status.Update(absl::InvalidArgumentError(path));
|
||||
}
|
||||
result->push_back(std::make_pair(tag, index));
|
||||
result->push_back({tag, index});
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
// Returns a message serialized deterministically.
|
||||
bool DeterministicallySerialize(const Message& proto, std::string* result) {
|
||||
proto_ns::io::StringOutputStream stream(result);
|
||||
proto_ns::io::CodedOutputStream output(&stream);
|
||||
output.SetSerializationDeterministic(true);
|
||||
return proto.SerializeToCodedStream(&output);
|
||||
}
|
||||
|
||||
// Serialize one field of a message.
|
||||
void SerializeField(const Message* message, const FieldDescriptor* field,
|
||||
std::vector<ProtoUtilLite::FieldValue>* result) {
|
||||
ProtoUtilLite::FieldValue message_bytes;
|
||||
CHECK(message->SerializePartialToString(&message_bytes));
|
||||
CHECK(DeterministicallySerialize(*message, &message_bytes));
|
||||
ProtoUtilLite::FieldAccess access(
|
||||
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
|
||||
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
|
||||
*result = *access.mutable_field_values();
|
||||
}
|
||||
|
||||
// Serialize a ProtoPath as a readable string.
|
||||
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
|
||||
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
|
||||
std::string ProtoPathJoin(ProtoPath path) {
|
||||
std::string result;
|
||||
for (ProtoUtilLite::ProtoPathEntry& e : path) {
|
||||
if (e.field_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
|
||||
} else if (e.map_id >= 0) {
|
||||
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
|
||||
"]");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the message value from a field at an index.
|
||||
const Message* GetFieldMessage(const Message& message,
|
||||
const FieldDescriptor* field, int index) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!field->is_repeated()) {
|
||||
return &message.GetReflection()->GetMessage(message, field);
|
||||
}
|
||||
if (index < message.GetReflection()->FieldSize(message, field)) {
|
||||
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns all FieldDescriptors including extensions.
|
||||
std::vector<const FieldDescriptor*> GetFields(const Message* src) {
|
||||
std::vector<const FieldDescriptor*> result;
|
||||
src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
|
||||
&result);
|
||||
for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
|
||||
result.push_back(src->GetDescriptor()->field(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Orders map entries in dst to match src.
|
||||
void OrderMapEntries(const Message* src, Message* dst,
|
||||
std::set<const Message*>* seen = nullptr) {
|
||||
std::unique_ptr<std::set<const Message*>> seen_owner;
|
||||
if (!seen) {
|
||||
seen_owner = std::make_unique<std::set<const Message*>>();
|
||||
seen = seen_owner.get();
|
||||
}
|
||||
if (seen->count(src) > 0) {
|
||||
return;
|
||||
} else {
|
||||
seen->insert(src);
|
||||
}
|
||||
for (auto field : GetFields(src)) {
|
||||
if (field->is_map()) {
|
||||
dst->GetReflection()->ClearField(dst, field);
|
||||
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||
const Message& entry =
|
||||
src->GetReflection()->GetRepeatedMessage(*src, field, j);
|
||||
dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
|
||||
}
|
||||
}
|
||||
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||
if (field->is_repeated()) {
|
||||
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
|
||||
OrderMapEntries(
|
||||
&src->GetReflection()->GetRepeatedMessage(*src, field, j),
|
||||
dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
|
||||
seen);
|
||||
}
|
||||
} else {
|
||||
OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
|
||||
dst->GetReflection()->MutableMessage(dst, field), seen);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copies a Message, keeping map entries in order.
|
||||
std::unique_ptr<Message> CloneMessage(const Message* message) {
|
||||
std::unique_ptr<Message> result(message->New());
|
||||
result->CopyFrom(*message);
|
||||
OrderMapEntries(message, result.get());
|
||||
return result;
|
||||
}
|
||||
|
||||
using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
|
||||
|
||||
// For a non-repeated field, move the most recently parsed field value
|
||||
// into the most recently parsed template expression.
|
||||
void StowFieldValue(Message* message, TemplateExpression* expression) {
|
||||
void StowFieldValue(Message* message, TemplateExpression* expression,
|
||||
MessageMap* stowed_messages) {
|
||||
const Reflection* reflection = message->GetReflection();
|
||||
const Descriptor* descriptor = message->GetDescriptor();
|
||||
ProtoUtilLite::ProtoPath path;
|
||||
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
|
||||
int field_number = path[path.size() - 1].first;
|
||||
int field_number = path[path.size() - 1].field_id;
|
||||
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
|
||||
|
||||
// Save each stowed message unserialized preserving map entry order.
|
||||
if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
|
||||
(*stowed_messages)[ProtoPathJoin(path)] =
|
||||
CloneMessage(GetFieldMessage(*message, field, 0));
|
||||
}
|
||||
|
||||
if (!field->is_repeated()) {
|
||||
std::vector<ProtoUtilLite::FieldValue> field_values;
|
||||
SerializeField(message, field, &field_values);
|
||||
|
@ -1402,6 +1512,112 @@ static void StripQuotes(std::string* str) {
|
|||
}
|
||||
}
|
||||
|
||||
// Returns the field or extension for field number.
|
||||
const FieldDescriptor* FindFieldByNumber(const Message* message,
|
||||
int field_num) {
|
||||
const FieldDescriptor* result =
|
||||
message->GetDescriptor()->FindFieldByNumber(field_num);
|
||||
if (result == nullptr) {
|
||||
result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the protobuf map key types from a ProtoPath.
|
||||
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
|
||||
std::vector<FieldType> result;
|
||||
for (auto& entry : path) {
|
||||
if (entry.map_id >= 0) {
|
||||
result.push_back(entry.key_type);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the text value for a string or numeric protobuf map key.
|
||||
std::string GetMapKey(const Message& map_entry) {
|
||||
auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
|
||||
auto reflection = map_entry.GetReflection();
|
||||
if (key_field->type() == FieldDescriptor::TYPE_STRING) {
|
||||
return reflection->GetString(map_entry, key_field);
|
||||
} else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
|
||||
return absl::StrCat(reflection->GetInt32(map_entry, key_field));
|
||||
} else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
|
||||
return absl::StrCat(reflection->GetInt64(map_entry, key_field));
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// Returns a Message store in CalculatorGraphTemplate::field_value.
|
||||
Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
|
||||
auto it = stowed_messages->find(ProtoPathJoin(proto_path));
|
||||
return (it != stowed_messages->end()) ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
const Message* GetNestedMessage(const Message& message,
|
||||
const FieldDescriptor* field,
|
||||
ProtoPath proto_path,
|
||||
MessageMap* stowed_messages) {
|
||||
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
|
||||
return nullptr;
|
||||
}
|
||||
const Message* result = FindStowedMessage(stowed_messages, proto_path);
|
||||
if (!result) {
|
||||
result = GetFieldMessage(message, field, proto_path.back().index);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Adjusts map-entries from indexes to keys.
|
||||
// Protobuf map-entry order is intentionally not preserved.
|
||||
absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
|
||||
// Copy the rules from the source CalculatorGraphTemplate.
|
||||
mediapipe::CalculatorGraphTemplate rules;
|
||||
rules.ParsePartialFromString(source->SerializePartialAsString());
|
||||
// Only the "source" Message knows all extension types.
|
||||
Message* config_0 = source->GetReflection()->MutableMessage(
|
||||
source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
|
||||
for (int i = 0; i < rules.rule().size(); ++i) {
|
||||
TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
|
||||
const Message* message = config_0;
|
||||
ProtoPath path;
|
||||
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
|
||||
for (int j = 0; j < path.size(); ++j) {
|
||||
int field_id = path[j].field_id;
|
||||
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
|
||||
ProtoPath prefix = {path.begin(), path.begin() + j + 1};
|
||||
message = GetNestedMessage(*message, field, prefix, stowed_messages);
|
||||
if (!message) {
|
||||
break;
|
||||
}
|
||||
if (field->is_map()) {
|
||||
const Message* map_entry = message;
|
||||
int key_id =
|
||||
map_entry->GetDescriptor()->FindFieldByName("key")->number();
|
||||
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
|
||||
map_entry->GetDescriptor()->FindFieldByName("key")->type());
|
||||
std::string key_value = GetMapKey(*map_entry);
|
||||
path[j] = {field_id, key_id, key_type, key_value};
|
||||
}
|
||||
}
|
||||
if (!rule->path().empty()) {
|
||||
*rule->mutable_path() = ProtoPathJoin(path);
|
||||
for (FieldType key_type : ProtoPathKeyTypes(path)) {
|
||||
*rule->mutable_key_type()->Add() = key_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Copy the rules back into the source CalculatorGraphTemplate.
|
||||
auto source_rules =
|
||||
source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
|
||||
source, source->GetDescriptor()->FindFieldByName("rule"));
|
||||
source_rules.Clear();
|
||||
for (auto& rule : rules.rule()) {
|
||||
source_rules.Add(rule);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class TemplateParser::Parser::MediaPipeParserImpl
|
||||
|
@ -1416,6 +1632,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
|
||||
// Copy the template rules into the output template "rule" field.
|
||||
success &= MergeFields(template_rules_, output).ok();
|
||||
// Replace map-entry indexes with map keys.
|
||||
success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
|
||||
return success;
|
||||
}
|
||||
|
||||
|
@ -1441,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
DO(ConsumeFieldTemplate(message));
|
||||
} else {
|
||||
DO(ConsumeField(message));
|
||||
StowFieldValue(message, expression);
|
||||
StowFieldValue(message, expression, &stowed_messages_);
|
||||
}
|
||||
DO(ConsumeEndTemplate());
|
||||
return true;
|
||||
|
@ -1652,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
|
|||
}
|
||||
|
||||
mediapipe::CalculatorGraphTemplate template_rules_;
|
||||
std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
|
||||
};
|
||||
|
||||
#undef DO
|
||||
|
|
17
mediapipe/framework/tool/testdata/BUILD
vendored
17
mediapipe/framework/tool/testdata/BUILD
vendored
|
@ -17,10 +17,13 @@ load(
|
|||
"//mediapipe/framework/tool:mediapipe_graph.bzl",
|
||||
"mediapipe_simple_subgraph",
|
||||
)
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
package(default_visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
name = "test_graph",
|
||||
|
@ -40,7 +43,6 @@ mediapipe_simple_subgraph(
|
|||
testonly = 1,
|
||||
graph = "dub_quad_test_subgraph.pbtxt",
|
||||
register_as = "DubQuadTestSubgraph",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:test_calculators",
|
||||
],
|
||||
|
@ -51,9 +53,18 @@ mediapipe_simple_subgraph(
|
|||
testonly = 1,
|
||||
graph = "nested_test_subgraph.pbtxt",
|
||||
register_as = "NestedTestSubgraph",
|
||||
visibility = ["//visibility:public"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
":dub_quad_test_subgraph",
|
||||
"//mediapipe/framework:test_calculators",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "frozen_generator_proto",
|
||||
srcs = ["frozen_generator.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
deps = [
|
||||
"//mediapipe/framework:packet_generator_proto",
|
||||
],
|
||||
)
|
||||
|
|
20
mediapipe/framework/tool/testdata/frozen_generator.proto
vendored
Normal file
20
mediapipe/framework/tool/testdata/frozen_generator.proto
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/packet_generator.proto";
|
||||
|
||||
message FrozenGeneratorOptions {
|
||||
extend mediapipe.PacketGeneratorOptions {
|
||||
optional FrozenGeneratorOptions ext = 225748738;
|
||||
}
|
||||
|
||||
// Path to file containing serialized proto of type tensorflow::GraphDef.
|
||||
optional string graph_proto_path = 1;
|
||||
|
||||
// This map defines the which streams are fed to which tensors in the model.
|
||||
map<string, string> tag_to_tensor_names = 2;
|
||||
|
||||
// Graph nodes to run to initialize the model.
|
||||
repeated string initialization_op_names = 4;
|
||||
}
|
|
@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize(
|
|||
input_side_packets_.clear();
|
||||
output_side_packets_.clear();
|
||||
stream_to_producer_.clear();
|
||||
output_streams_to_consumer_nodes_.clear();
|
||||
input_streams_.clear();
|
||||
output_streams_.clear();
|
||||
owned_packet_types_.clear();
|
||||
|
@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
|
|||
<< " does not have a corresponding output stream.";
|
||||
}
|
||||
}
|
||||
// Add this node as a consumer of this edge's output stream.
|
||||
if (edge_info.upstream > -1) {
|
||||
auto parent_node = output_streams_[edge_info.upstream].parent_node;
|
||||
if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) {
|
||||
int this_idx = node_type_info->Node().index;
|
||||
output_streams_to_consumer_nodes_[edge_info.upstream].push_back(
|
||||
this_idx);
|
||||
}
|
||||
}
|
||||
|
||||
edge_info.parent_node = node_type_info->Node();
|
||||
edge_info.name = name;
|
||||
|
|
|
@ -282,6 +282,14 @@ class ValidatedGraphConfig {
|
|||
return output_streams_[iter->second].parent_node.index;
|
||||
}
|
||||
|
||||
std::vector<int> OutputStreamToConsumers(int idx) const {
|
||||
auto iter = output_streams_to_consumer_nodes_.find(idx);
|
||||
if (iter == output_streams_to_consumer_nodes_.end()) {
|
||||
return {};
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// Returns the registered type name of the specified side packet if
|
||||
// it can be determined, otherwise an appropriate error is returned.
|
||||
absl::StatusOr<std::string> RegisteredSidePacketTypeName(
|
||||
|
@ -418,6 +426,10 @@ class ValidatedGraphConfig {
|
|||
|
||||
// Mapping from stream name to the output_streams_ index which produces it.
|
||||
std::map<std::string, int> stream_to_producer_;
|
||||
|
||||
// Mapping from output streams to consumer node ids. Used for profiling.
|
||||
std::map<int, std::vector<int>> output_streams_to_consumer_nodes_;
|
||||
|
||||
// Mapping from side packet name to the output_side_packets_ index
|
||||
// which produces it.
|
||||
std::map<std::string, int> side_packet_to_producer_;
|
||||
|
|
|
@ -289,7 +289,9 @@ cc_library(
|
|||
deps = [
|
||||
":gpu_buffer_format",
|
||||
":gpu_buffer_storage",
|
||||
"@com_google_absl//absl/functional:bind_front",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:logging",
|
||||
":gpu_buffer_storage_image_frame",
|
||||
|
@ -472,13 +474,13 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"Accelerate",
|
||||
"CoreGraphics",
|
||||
"CoreVideo",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/objc:util"],
|
||||
deps = [
|
||||
"//mediapipe/objc:util",
|
||||
"//third_party/apple_frameworks:Accelerate",
|
||||
"//third_party/apple_frameworks:CoreGraphics",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
|
@ -510,13 +512,11 @@ objc_library(
|
|||
"-x objective-c++",
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
"@com_google_absl//absl/time",
|
||||
"@google_toolbox_for_mac//:GTM_Defines",
|
||||
],
|
||||
|
@ -808,15 +808,13 @@ objc_library(
|
|||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gpu_shared_data_internal",
|
||||
":graph_support",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
"@google_toolbox_for_mac//:GTM_Defines",
|
||||
],
|
||||
)
|
||||
|
@ -1020,16 +1018,14 @@ objc_library(
|
|||
name = "metal_copy_calculator",
|
||||
srcs = ["MetalCopyCalculator.mm"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
":simple_shaders_mtl",
|
||||
"//mediapipe/gpu:copy_calculator_cc_proto",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1038,15 +1034,13 @@ objc_library(
|
|||
name = "metal_rgb_weight_calculator",
|
||||
srcs = ["MetalRgbWeightCalculator.mm"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
":simple_shaders_mtl",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1055,15 +1049,13 @@ objc_library(
|
|||
name = "metal_sobel_calculator",
|
||||
srcs = ["MetalSobelCalculator.mm"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
":simple_shaders_mtl",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1072,15 +1064,13 @@ objc_library(
|
|||
name = "metal_sobel_compute_calculator",
|
||||
srcs = ["MetalSobelComputeCalculator.mm"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
":simple_shaders_mtl",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1090,15 +1080,13 @@ objc_library(
|
|||
srcs = ["MPSSobelCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
"MetalPerformanceShaders",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
"//third_party/apple_frameworks:MetalPerformanceShaders",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1106,15 +1094,13 @@ objc_library(
|
|||
objc_library(
|
||||
name = "mps_threshold_calculator",
|
||||
srcs = ["MPSThresholdCalculator.mm"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
"MetalPerformanceShaders",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":MPPMetalHelper",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Metal",
|
||||
"//third_party/apple_frameworks:MetalPerformanceShaders",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/functional/bind_front.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
@ -25,57 +26,101 @@ struct StorageTypeFormatter {
|
|||
} // namespace
|
||||
|
||||
std::string GpuBuffer::DebugString() const {
|
||||
return absl::StrCat("GpuBuffer[",
|
||||
absl::StrJoin(storages_, ", ", StorageTypeFormatter()),
|
||||
"]");
|
||||
return holder_ ? absl::StrCat("GpuBuffer[", width(), "x", height(), " ",
|
||||
format(), " as ", holder_->DebugString(), "]")
|
||||
: "GpuBuffer[invalid]";
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
|
||||
std::string GpuBuffer::StorageHolder::DebugString() const {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
return absl::StrJoin(storages_, ", ", StorageTypeFormatter());
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView(
|
||||
TypeId view_provider_type, bool for_writing) const {
|
||||
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
|
||||
std::shared_ptr<internal::GpuBufferStorage> chosen_storage;
|
||||
std::function<std::shared_ptr<internal::GpuBufferStorage>()> conversion;
|
||||
|
||||
// First see if any current storage supports the view.
|
||||
for (const auto& s : storages_) {
|
||||
if (s->can_down_cast_to(view_provider_type)) {
|
||||
chosen_storage = &s;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Then try to convert existing storages to one that does.
|
||||
// TODO: choose best conversion.
|
||||
if (!chosen_storage) {
|
||||
{
|
||||
absl::MutexLock lock(&mutex_);
|
||||
// First see if any current storage supports the view.
|
||||
for (const auto& s : storages_) {
|
||||
if (auto converter = internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageConverterForViewProvider(
|
||||
view_provider_type, s->storage_type())) {
|
||||
if (auto new_storage = converter(s)) {
|
||||
storages_.push_back(new_storage);
|
||||
chosen_storage = &storages_.back();
|
||||
if (s->can_down_cast_to(view_provider_type)) {
|
||||
chosen_storage = s;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Then try to convert existing storages to one that does.
|
||||
// TODO: choose best conversion.
|
||||
if (!chosen_storage) {
|
||||
for (const auto& s : storages_) {
|
||||
if (auto converter = internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageConverterForViewProvider(
|
||||
view_provider_type, s->storage_type())) {
|
||||
conversion = absl::bind_front(converter, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid invoking a converter or factory while holding the mutex.
|
||||
// Two reasons:
|
||||
// 1. Readers that don't need a conversion will not be blocked.
|
||||
// 2. We use mutexes to make sure GL contexts are not used simultaneously on
|
||||
// different threads, and we also rely on Mutex's deadlock detection
|
||||
// heuristic, which enforces a consistent mutex acquisition order.
|
||||
// This function is likely to be called within a GL context, and the
|
||||
// conversion function may in turn use a GL context, and this may cause a
|
||||
// false positive in the deadlock detector.
|
||||
// TODO: we could use Mutex::ForgetDeadlockInfo instead.
|
||||
if (conversion) {
|
||||
auto new_storage = conversion();
|
||||
absl::MutexLock lock(&mutex_);
|
||||
// Another reader might have already completed and inserted the same
|
||||
// conversion. TODO: prevent this?
|
||||
for (const auto& s : storages_) {
|
||||
if (s->can_down_cast_to(view_provider_type)) {
|
||||
chosen_storage = s;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!chosen_storage) {
|
||||
storages_.push_back(std::move(new_storage));
|
||||
chosen_storage = storages_.back();
|
||||
}
|
||||
}
|
||||
|
||||
if (for_writing) {
|
||||
// This will temporarily hold storages to be released, and do so while the
|
||||
// lock is not held (see above).
|
||||
decltype(storages_) old_storages;
|
||||
using std::swap;
|
||||
if (chosen_storage) {
|
||||
// Discard all other storages.
|
||||
storages_ = {*chosen_storage};
|
||||
chosen_storage = &storages_.back();
|
||||
absl::MutexLock lock(&mutex_);
|
||||
swap(old_storages, storages_);
|
||||
storages_ = {chosen_storage};
|
||||
} else {
|
||||
// Allocate a new storage supporting the requested view.
|
||||
if (auto factory =
|
||||
internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageFactoryForViewProvider(view_provider_type)) {
|
||||
if (auto new_storage = factory(width(), height(), format())) {
|
||||
if (auto new_storage = factory(width_, height_, format_)) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
swap(old_storages, storages_);
|
||||
storages_ = {std::move(new_storage)};
|
||||
chosen_storage = &storages_.back();
|
||||
chosen_storage = storages_.back();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chosen_storage ? chosen_storage->get() : nullptr;
|
||||
|
||||
// It is ok to return a non-owning storage pointer here because this object
|
||||
// ensures the storage's lifetime. Overwriting a GpuBuffer while readers are
|
||||
// active would violate this, but it's not allowed in MediaPipe.
|
||||
return chosen_storage ? chosen_storage.get() : nullptr;
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
|
||||
|
@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
|
|||
GpuBuffer::GetStorageForView(view_provider_type, for_writing);
|
||||
CHECK(chosen_storage) << "no view provider found for requested view "
|
||||
<< view_provider_type.name() << "; storages available: "
|
||||
<< absl::StrJoin(storages_, ", ",
|
||||
StorageTypeFormatter());
|
||||
<< (holder_ ? holder_->DebugString() : "invalid");
|
||||
DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
|
||||
return *chosen_storage;
|
||||
}
|
||||
|
|
|
@ -15,9 +15,12 @@
|
|||
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_
|
||||
#define MEDIAPIPE_GPU_GPU_BUFFER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||
|
@ -56,8 +59,7 @@ class GpuBuffer {
|
|||
// Creates an empty buffer of a given size and format. It will be allocated
|
||||
// when a view is requested.
|
||||
GpuBuffer(int width, int height, Format format)
|
||||
: GpuBuffer(std::make_shared<PlaceholderGpuBufferStorage>(width, height,
|
||||
format)) {}
|
||||
: holder_(std::make_shared<StorageHolder>(width, height, format)) {}
|
||||
|
||||
// Copy and move constructors and assignment operators are supported.
|
||||
GpuBuffer(const GpuBuffer& other) = default;
|
||||
|
@ -70,9 +72,8 @@ class GpuBuffer {
|
|||
// are not portable. Applications and calculators should normally obtain
|
||||
// GpuBuffers in a portable way from the framework, e.g. using
|
||||
// GpuBufferMultiPool.
|
||||
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
|
||||
storages_.push_back(std::move(storage));
|
||||
}
|
||||
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
|
||||
: holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
// This is used to support backward-compatible construction of GpuBuffer from
|
||||
|
@ -84,9 +85,11 @@ class GpuBuffer {
|
|||
: GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
int width() const { return current_storage().width(); }
|
||||
int height() const { return current_storage().height(); }
|
||||
GpuBufferFormat format() const { return current_storage().format(); }
|
||||
int width() const { return holder_ ? holder_->width() : 0; }
|
||||
int height() const { return holder_ ? holder_->height() : 0; }
|
||||
GpuBufferFormat format() const {
|
||||
return holder_ ? holder_->format() : GpuBufferFormat::kUnknown;
|
||||
}
|
||||
|
||||
// Converts to true iff valid.
|
||||
explicit operator bool() const { return operator!=(nullptr); }
|
||||
|
@ -122,31 +125,17 @@ class GpuBuffer {
|
|||
// using views.
|
||||
template <class T>
|
||||
std::shared_ptr<T> internal_storage() const {
|
||||
for (const auto& s : storages_)
|
||||
if (s->down_cast<T>()) return std::static_pointer_cast<T>(s);
|
||||
return nullptr;
|
||||
return holder_ ? holder_->internal_storage<T>() : nullptr;
|
||||
}
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
private:
|
||||
class PlaceholderGpuBufferStorage
|
||||
: public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {
|
||||
public:
|
||||
PlaceholderGpuBufferStorage(int width, int height, Format format)
|
||||
: width_(width), height_(height), format_(format) {}
|
||||
int width() const override { return width_; }
|
||||
int height() const override { return height_; }
|
||||
GpuBufferFormat format() const override { return format_; }
|
||||
|
||||
private:
|
||||
int width_ = 0;
|
||||
int height_ = 0;
|
||||
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
|
||||
};
|
||||
|
||||
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
|
||||
bool for_writing) const;
|
||||
bool for_writing) const {
|
||||
return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
|
||||
bool for_writing) const;
|
||||
|
@ -158,25 +147,49 @@ class GpuBuffer {
|
|||
.template down_cast<VP>();
|
||||
}
|
||||
|
||||
std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {
|
||||
static auto placeholder =
|
||||
std::static_pointer_cast<internal::GpuBufferStorage>(
|
||||
std::make_shared<PlaceholderGpuBufferStorage>(
|
||||
0, 0, GpuBufferFormat::kUnknown));
|
||||
return placeholder;
|
||||
}
|
||||
// This class manages a set of alternative storages for the contents of a
|
||||
// GpuBuffer. GpuBuffer was originally designed as a reference-type object,
|
||||
// where a copy represents another reference to the same contents, so multiple
|
||||
// GpuBuffer instances can share the same StorageHolder.
|
||||
class StorageHolder {
|
||||
public:
|
||||
explicit StorageHolder(std::shared_ptr<internal::GpuBufferStorage> storage)
|
||||
: StorageHolder(storage->width(), storage->height(),
|
||||
storage->format()) {
|
||||
storages_.push_back(std::move(storage));
|
||||
}
|
||||
explicit StorageHolder(int width, int height, Format format)
|
||||
: width_(width), height_(height), format_(format) {}
|
||||
|
||||
const internal::GpuBufferStorage& current_storage() const {
|
||||
return storages_.empty() ? *no_storage() : *storages_[0];
|
||||
}
|
||||
int width() const { return width_; }
|
||||
int height() const { return height_; }
|
||||
GpuBufferFormat format() const { return format_; }
|
||||
|
||||
internal::GpuBufferStorage& current_storage() {
|
||||
return storages_.empty() ? *no_storage() : *storages_[0];
|
||||
}
|
||||
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
|
||||
bool for_writing) const;
|
||||
|
||||
// This is mutable because view methods that do not change the contents may
|
||||
// still need to allocate new storages.
|
||||
mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_;
|
||||
template <class T>
|
||||
std::shared_ptr<T> internal_storage() const {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
for (const auto& s : storages_)
|
||||
if (s->down_cast<T>()) return std::static_pointer_cast<T>(s);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
private:
|
||||
int width_ = 0;
|
||||
int height_ = 0;
|
||||
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
|
||||
// This is mutable because view methods that do not change the contents may
|
||||
// still need to allocate new storages.
|
||||
mutable absl::Mutex mutex_;
|
||||
mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
std::shared_ptr<StorageHolder> holder_;
|
||||
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer);
|
||||
|
@ -184,15 +197,15 @@ class GpuBuffer {
|
|||
};
|
||||
|
||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
||||
return storages_.empty();
|
||||
return holder_ == other;
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
||||
return storages_ == other.storages_;
|
||||
return holder_ == other.holder_;
|
||||
}
|
||||
|
||||
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
||||
storages_.clear();
|
||||
holder_ = other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/tool/test_util.h"
|
||||
#include "mediapipe/gpu/gl_texture_buffer.h"
|
||||
#include "mediapipe/gpu/gl_texture_util.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
|
||||
|
@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) {
|
|||
EXPECT_TRUE(true);
|
||||
}
|
||||
|
||||
TEST_F(GpuBufferTest, CopiesShareConversions) {
|
||||
GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32);
|
||||
{
|
||||
std::shared_ptr<ImageFrame> view = buffer.GetWriteView<ImageFrame>();
|
||||
FillImageFrameRGBA(*view, 255, 0, 0, 255);
|
||||
}
|
||||
|
||||
GpuBuffer other_handle = buffer;
|
||||
RunInGlContext([&buffer] {
|
||||
TempGlFramebuffer fb;
|
||||
auto view = buffer.GetReadView<GlTextureView>(0);
|
||||
});
|
||||
|
||||
// Check that other_handle also sees the same GlTextureBuffer as buffer.
|
||||
// Note that this is deliberately written so that it still passes on platforms
|
||||
// where we use another storage for GL textures (they will both be null).
|
||||
// TODO: expose more accessors for testing?
|
||||
EXPECT_EQ(other_handle.internal_storage<GlTextureBuffer>(),
|
||||
buffer.internal_storage<GlTextureBuffer>());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
||||
|
|
40
mediapipe/gpu/metal_shared_resources.h
Normal file
40
mediapipe/gpu/metal_shared_resources.h
Normal file
|
@ -0,0 +1,40 @@
|
|||
#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
|
||||
#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
|
||||
|
||||
#import <CoreVideo/CVMetalTextureCache.h>
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Foundation/NSObject.h>
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#ifndef __OBJC__
|
||||
#error This class must be built as Objective-C++.
|
||||
#endif // !__OBJC__
|
||||
|
||||
@interface MPPMetalSharedResources : NSObject {
|
||||
}
|
||||
|
||||
- (instancetype)init NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@property(readonly) id<MTLDevice> mtlDevice;
|
||||
@property(readonly) id<MTLCommandQueue> mtlCommandQueue;
|
||||
#if COREVIDEO_SUPPORTS_METAL
|
||||
@property(readonly) CVMetalTextureCacheRef mtlTextureCache;
|
||||
#endif
|
||||
|
||||
@end
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class MetalSharedResources {
|
||||
public:
|
||||
MetalSharedResources();
|
||||
~MetalSharedResources();
|
||||
MPPMetalSharedResources* resources() { return resources_; }
|
||||
|
||||
private:
|
||||
MPPMetalSharedResources* resources_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
|
73
mediapipe/gpu/metal_shared_resources.mm
Normal file
73
mediapipe/gpu/metal_shared_resources.mm
Normal file
|
@ -0,0 +1,73 @@
|
|||
#import "mediapipe/gpu/metal_shared_resources.h"
|
||||
|
||||
@interface MPPMetalSharedResources ()
|
||||
@end
|
||||
|
||||
@implementation MPPMetalSharedResources {
|
||||
}
|
||||
|
||||
@synthesize mtlDevice = _mtlDevice;
|
||||
@synthesize mtlCommandQueue = _mtlCommandQueue;
|
||||
#if COREVIDEO_SUPPORTS_METAL
|
||||
@synthesize mtlTextureCache = _mtlTextureCache;
|
||||
#endif
|
||||
|
||||
- (instancetype)init {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (void)dealloc {
|
||||
#if COREVIDEO_SUPPORTS_METAL
|
||||
if (_mtlTextureCache) {
|
||||
CFRelease(_mtlTextureCache);
|
||||
_mtlTextureCache = NULL;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
- (id<MTLDevice>)mtlDevice {
|
||||
@synchronized(self) {
|
||||
if (!_mtlDevice) {
|
||||
_mtlDevice = MTLCreateSystemDefaultDevice();
|
||||
}
|
||||
}
|
||||
return _mtlDevice;
|
||||
}
|
||||
|
||||
- (id<MTLCommandQueue>)mtlCommandQueue {
|
||||
@synchronized(self) {
|
||||
if (!_mtlCommandQueue) {
|
||||
_mtlCommandQueue = [self.mtlDevice newCommandQueue];
|
||||
}
|
||||
}
|
||||
return _mtlCommandQueue;
|
||||
}
|
||||
|
||||
#if COREVIDEO_SUPPORTS_METAL
|
||||
- (CVMetalTextureCacheRef)mtlTextureCache {
|
||||
@synchronized(self) {
|
||||
if (!_mtlTextureCache) {
|
||||
CVReturn __unused err =
|
||||
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
|
||||
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
|
||||
self.mtlDevice);
|
||||
// TODO: register and flush metal caches too.
|
||||
}
|
||||
}
|
||||
return _mtlTextureCache;
|
||||
}
|
||||
#endif
|
||||
|
||||
@end
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
MetalSharedResources::MetalSharedResources() {
|
||||
resources_ = [[MPPMetalSharedResources alloc] init];
|
||||
}
|
||||
MetalSharedResources::~MetalSharedResources() {}
|
||||
|
||||
} // namespace mediapipe
|
49
mediapipe/gpu/metal_shared_resources_test.mm
Normal file
49
mediapipe/gpu/metal_shared_resources_test.mm
Normal file
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <UIKit/UIKit.h>
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/port/threadpool.h"
|
||||
|
||||
#import "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||
#import "mediapipe/gpu/metal_shared_resources.h"
|
||||
|
||||
@interface MPPMetalSharedResourcesTests : XCTestCase {
|
||||
}
|
||||
@end
|
||||
|
||||
@implementation MPPMetalSharedResourcesTests
|
||||
|
||||
// This test verifies that the internal Objective-C object is correctly
|
||||
// released when the C++ wrapper is released.
|
||||
- (void)testCorrectlyReleased {
|
||||
__weak id metalRes = nil;
|
||||
std::weak_ptr<mediapipe::GpuResources> weakGpuRes;
|
||||
@autoreleasepool {
|
||||
auto maybeGpuRes = mediapipe::GpuResources::Create();
|
||||
XCTAssertTrue(maybeGpuRes.ok());
|
||||
weakGpuRes = *maybeGpuRes;
|
||||
metalRes = (**maybeGpuRes).metal_shared().resources();
|
||||
XCTAssertNotEqual(weakGpuRes.lock(), nullptr);
|
||||
XCTAssertNotNil(metalRes);
|
||||
}
|
||||
XCTAssertEqual(weakGpuRes.lock(), nullptr);
|
||||
XCTAssertNil(metalRes);
|
||||
}
|
||||
|
||||
@end
|
|
@ -43,6 +43,7 @@ cc_library(
|
|||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/calculators/image:color_convert_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||
"//mediapipe/calculators/image:recolor_calculator",
|
||||
"//mediapipe/calculators/image:set_alpha_calculator",
|
||||
|
|
|
@ -60,7 +60,14 @@ node {
|
|||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "PREV_LOOP:previous_hair_mask"
|
||||
output_stream: "PREV_LOOP:previous_hair_mask_rgb"
|
||||
}
|
||||
|
||||
# Converts the 4 channel hair mask to a single channel mask
|
||||
node {
|
||||
calculator: "ColorConvertCalculator"
|
||||
input_stream: "RGB_IN:previous_hair_mask_rgb"
|
||||
output_stream: "GRAY_OUT:previous_hair_mask"
|
||||
}
|
||||
|
||||
# Embeds the hair mask generated from the previous round of hair segmentation
|
||||
|
|
|
@ -97,7 +97,6 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_file_properties_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
|
|
@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"])
|
|||
mediapipe_proto_library(
|
||||
name = "gl_animation_overlay_calculator_proto",
|
||||
srcs = ["gl_animation_overlay_calculator.proto"],
|
||||
def_options_lib = False,
|
||||
visibility = ["//visibility:public"],
|
||||
exports = [
|
||||
"//mediapipe/gpu:gl_animation_overlay_calculator_proto",
|
||||
|
|
|
@ -84,12 +84,11 @@ cc_library(
|
|||
deps = [
|
||||
":class_registry",
|
||||
":jni_util",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_profile_cc_proto",
|
||||
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
|
||||
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix_data_cc_proto",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_profile_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
|
|
1
mediapipe/model_maker/MANIFEST.in
Normal file
1
mediapipe/model_maker/MANIFEST.in
Normal file
|
@ -0,0 +1 @@
|
|||
recursive-include pip_src/mediapipe_model_maker/models *
|
|
@ -11,3 +11,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.vision import image_classifier
|
||||
from mediapipe.model_maker.python.vision import gesture_recognizer
|
||||
from mediapipe.model_maker.python.text import text_classifier
|
||||
|
|
|
@ -19,7 +19,7 @@ import os
|
|||
|
||||
|
||||
def get_absolute_path(file_path: str) -> str:
|
||||
"""Gets the absolute path of a file.
|
||||
"""Gets the absolute path of a file in the model_maker directory.
|
||||
|
||||
Args:
|
||||
file_path: The path to a file relative to the `mediapipe` dir
|
||||
|
@ -27,10 +27,17 @@ def get_absolute_path(file_path: str) -> str:
|
|||
Returns:
|
||||
The full path of the file
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||
# with the `path` which defines the relative path under mediapipe/, it
|
||||
# yields to the absolute path of the model files directory.
|
||||
# Extract the file path before and including 'model_maker' as the
|
||||
# `mm_base_dir`. By joining it with the `path` after 'model_maker/', it
|
||||
# yields to the absolute path of the model files directory. We must join
|
||||
# on 'model_maker' because in the pypi package, the 'model_maker' directory
|
||||
# is renamed to 'mediapipe_model_maker'. So we have to join on model_maker
|
||||
# to ensure that the `mm_base_dir` path includes the renamed
|
||||
# 'mediapipe_model_maker' directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, file_path)
|
||||
cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker')
|
||||
mm_base_dir = cwd[:cwd_stop_idx]
|
||||
file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1
|
||||
mm_relative_path = file_path[file_path_start_idx:]
|
||||
absolute_path = os.path.join(mm_base_dir, mm_relative_path)
|
||||
return absolute_path
|
||||
|
|
|
@ -4,21 +4,10 @@
|
|||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "segmenter_options_proto",
|
||||
srcs = ["segmenter_options.proto"],
|
||||
)
|
|
@ -35,20 +35,21 @@ py_library(
|
|||
srcs = ["constants.py"],
|
||||
)
|
||||
|
||||
# TODO: Change to py_library after migrating the MediaPipe hand solution
|
||||
# library to MediaPipe hand task library.
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
":constants",
|
||||
":metadata_writer",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/data:data_util",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Remove notsan tag once tasks no longer has race condition issue
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
|
@ -56,10 +57,11 @@ py_test(
|
|||
":testdata",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:hand_landmarker",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -131,6 +133,7 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: Remove notsan tag once tasks no longer has race condition issue
|
||||
py_test(
|
||||
name = "gesture_recognizer_test",
|
||||
size = "large",
|
||||
|
@ -140,6 +143,7 @@ py_test(
|
|||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
shard_count = 2,
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":gesture_recognizer_import",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
|
|
|
@ -16,16 +16,22 @@
|
|||
import dataclasses
|
||||
import os
|
||||
import random
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.core.data import data_util
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
|
||||
|
||||
_Image = image_module.Image
|
||||
_HandLandmarker = hand_landmarker_module.HandLandmarker
|
||||
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
|
||||
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -59,7 +65,7 @@ class HandData:
|
|||
handedness: List[float]
|
||||
|
||||
|
||||
def _validate_data_sample(data: NamedTuple) -> bool:
|
||||
def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
|
||||
"""Validates the input hand data sample.
|
||||
|
||||
Args:
|
||||
|
@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
|
|||
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
||||
or any of these attributes' values are none. Otherwise, True.
|
||||
"""
|
||||
if (not hasattr(data, 'multi_hand_landmarks') or
|
||||
data.multi_hand_landmarks is None):
|
||||
if data.hand_landmarks is None or not data.hand_landmarks:
|
||||
return False
|
||||
if (not hasattr(data, 'multi_hand_world_landmarks') or
|
||||
data.multi_hand_world_landmarks is None):
|
||||
if data.hand_world_landmarks is None or not data.hand_world_landmarks:
|
||||
return False
|
||||
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
|
||||
if data.handedness is None or not data.handedness:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_hand_data(all_image_paths: List[str],
|
||||
min_detection_confidence: float) -> Optional[HandData]:
|
||||
min_detection_confidence: float) -> List[Optional[HandData]]:
|
||||
"""Computes hand data (landmarks and handedness) in the input image.
|
||||
|
||||
Args:
|
||||
|
@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
|
|||
A HandData object. Returns None if no hand is detected.
|
||||
"""
|
||||
hand_data_result = []
|
||||
with mp_hands.Hands(
|
||||
static_image_mode=True,
|
||||
max_num_hands=1,
|
||||
min_detection_confidence=min_detection_confidence) as hands:
|
||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
hand_landmarker_options = _HandLandmarkerOptions(
|
||||
base_options=base_options_module.BaseOptions(
|
||||
model_asset_buffer=hand_landmarker_writer.populate()),
|
||||
num_hands=1,
|
||||
min_hand_detection_confidence=min_detection_confidence,
|
||||
min_hand_presence_confidence=0.5,
|
||||
min_tracking_confidence=1,
|
||||
)
|
||||
with _HandLandmarker.create_from_options(
|
||||
hand_landmarker_options) as hand_landmarker:
|
||||
for path in all_image_paths:
|
||||
tf.compat.v1.logging.info('Loading image %s', path)
|
||||
image = data_util.load_image(path)
|
||||
# Flip image around y-axis for correct handedness output
|
||||
image = cv2.flip(image, 1)
|
||||
data = hands.process(image)
|
||||
image = _Image.create_from_file(path)
|
||||
data = hand_landmarker.detect(image)
|
||||
if not _validate_data_sample(data):
|
||||
hand_data_result.append(None)
|
||||
continue
|
||||
hand_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
|
||||
hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
|
||||
for hand_landmark in data.hand_landmarks[0]]
|
||||
hand_world_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
|
||||
] for hand_landmark in data.hand_world_landmarks[0]]
|
||||
handedness_scores = [
|
||||
handedness.score
|
||||
for handedness in data.multi_handedness[0].classification
|
||||
handedness.score for handedness in data.handedness[0]
|
||||
]
|
||||
hand_data_result.append(
|
||||
HandData(
|
||||
|
|
|
@ -12,21 +12,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
from typing import NamedTuple
|
||||
import unittest
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
from mediapipe.tasks.python.vision import hand_landmarker
|
||||
|
||||
_TEST_DATA_DIRNAME = 'raw_data'
|
||||
|
||||
|
@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
train_data, test_data = data.split(0.5)
|
||||
|
||||
self.assertLen(train_data, 17)
|
||||
self.assertLen(train_data, 16)
|
||||
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertEqual(train_data.num_classes, 4)
|
||||
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
self.assertLen(test_data, 18)
|
||||
self.assertLen(test_data, 16)
|
||||
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
|
@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertLen(data, 32)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
|
@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertLen(data, 32)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_landmark',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmark', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_handedness',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=None, hand_landmarks=[[2]],
|
||||
hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_world_landmarks',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmark',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_hand_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=None,
|
||||
hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_handed',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handed'
|
||||
])(1, 2, 3)),
|
||||
testcase_name='none_hand_world_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[[2]],
|
||||
hand_world_landmarks=None)),
|
||||
dict(
|
||||
testcase_name='multi_hand_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(None, 2, 3)),
|
||||
testcase_name='empty_handedness',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='multi_hand_world_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, None, 3)),
|
||||
testcase_name='empty_hand_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
|
||||
dict(
|
||||
testcase_name='multi_handedness_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, None)),
|
||||
testcase_name='empty_hand_world_landmarks',
|
||||
hand=hand_landmarker.HandLandmarkerResult(
|
||||
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
||||
)
|
||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||
with unittest.mock.patch.object(
|
||||
mp_hands.Hands, 'process', return_value=hand):
|
||||
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||
dataset.Dataset.from_folder(
|
||||
|
|
|
@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
|||
return f.read()
|
||||
|
||||
|
||||
class HandLandmarkerMetadataWriter:
|
||||
"""MetadataWriter to write the model asset bundle for HandLandmarker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hand_detector_model_buffer: bytearray,
|
||||
hand_landmarks_detector_model_buffer: bytearray,
|
||||
) -> None:
|
||||
"""Initializes HandLandmarkerMetadataWriter to write model asset bundle.
|
||||
|
||||
Args:
|
||||
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||
the TFLite hand detector model file.
|
||||
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite hand landmarks detector model file.
|
||||
"""
|
||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||
self._temp_folder = tempfile.TemporaryDirectory()
|
||||
|
||||
def __del__(self):
|
||||
if os.path.exists(self._temp_folder.name):
|
||||
self._temp_folder.cleanup()
|
||||
|
||||
def populate(self):
|
||||
"""Creates the model asset bundle for hand landmarker task.
|
||||
|
||||
Returns:
|
||||
Model asset bundle in bytes
|
||||
"""
|
||||
landmark_models = {
|
||||
_HAND_DETECTOR_TFLITE_NAME:
|
||||
self._hand_detector_model_buffer,
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||
self._hand_landmarks_detector_model_buffer
|
||||
}
|
||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(landmark_models,
|
||||
output_hand_landmarker_path)
|
||||
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
|
||||
return hand_landmarker_model_buffer
|
||||
|
||||
|
||||
class MetadataWriter:
|
||||
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||
|
||||
|
@ -86,8 +130,8 @@ class MetadataWriter:
|
|||
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||
gesture classifier metadata into the TFLite file.
|
||||
"""
|
||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||
self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||
|
@ -147,16 +191,8 @@ class MetadataWriter:
|
|||
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||
"""
|
||||
# Creates the model asset bundle for hand landmarker task.
|
||||
landmark_models = {
|
||||
_HAND_DETECTOR_TFLITE_NAME:
|
||||
self._hand_detector_model_buffer,
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||
self._hand_landmarks_detector_model_buffer
|
||||
}
|
||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(landmark_models,
|
||||
output_hand_landmarker_path)
|
||||
hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
|
||||
)
|
||||
|
||||
# Write metadata into custom gesture classifier model.
|
||||
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
||||
|
@ -179,7 +215,7 @@ class MetadataWriter:
|
|||
# graph.
|
||||
gesture_recognizer_models = {
|
||||
_HAND_LANDMARKER_BUNDLE_NAME:
|
||||
read_file(output_hand_landmarker_path),
|
||||
hand_landmarker_model_buffer,
|
||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
||||
read_file(output_hand_gesture_recognizer_path),
|
||||
}
|
||||
|
|
|
@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
|
|||
|
||||
class MetadataWriterTest(tf.test.TestCase):
|
||||
|
||||
def test_hand_landmarker_metadata_writer(self):
|
||||
# Use dummy model buffer for unit test only.
|
||||
hand_detector_model_buffer = b"\x11\x12"
|
||||
hand_landmarks_detector_model_buffer = b"\x22"
|
||||
writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
model_bundle_content = writer.populate()
|
||||
model_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||
"hand_landmarker.task")
|
||||
with open(model_bundle_filepath, "wb") as f:
|
||||
f.write(model_bundle_content)
|
||||
|
||||
with zipfile.ZipFile(model_bundle_filepath) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
|
||||
|
||||
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||
# Use dummy model buffer for unit test only.
|
||||
hand_detector_model_buffer = b"\x11\x12"
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
absl-py
|
||||
mediapipe==0.9.1
|
||||
numpy
|
||||
opencv-contrib-python
|
||||
tensorflow
|
||||
opencv-python
|
||||
tensorflow>=2.10
|
||||
tensorflow-datasets
|
||||
tensorflow-hub
|
||||
tf-models-official>=2.10.1
|
||||
|
|
147
mediapipe/model_maker/setup.py
Normal file
147
mediapipe/model_maker/setup.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
Setup for Mediapipe-Model-Maker package with setuptools.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import setuptools
|
||||
|
||||
|
||||
__version__ = 'dev'
|
||||
MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
# Build dir to copy all necessary files and build package
|
||||
SRC_NAME = 'pip_src'
|
||||
BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME)
|
||||
BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker')
|
||||
|
||||
|
||||
def _parse_requirements(path):
|
||||
with open(os.path.join(MM_ROOT_PATH, path)) as f:
|
||||
return [
|
||||
line.rstrip()
|
||||
for line in f
|
||||
if not (line.isspace() or line.startswith('#'))
|
||||
]
|
||||
|
||||
|
||||
def _copy_to_pip_src_dir(file):
|
||||
"""Copy a file from bazel-bin to the pip_src dir."""
|
||||
dst = file
|
||||
dst_dir = os.path.dirname(dst)
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file)
|
||||
shutil.copyfile(src_file, file)
|
||||
|
||||
|
||||
def _setup_build_dir():
|
||||
"""Setup the BUILD_DIR directory to build the mediapipe_model_maker package.
|
||||
|
||||
We need to create a new BUILD_DIR directory because any references to the path
|
||||
`mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to
|
||||
avoid conflicting with the mediapipe package name.
|
||||
This setup function performs the following actions:
|
||||
1. Copy python source code into BUILD_DIR and rename imports to
|
||||
mediapipe_model_maker
|
||||
2. Download models from GCS into BUILD_DIR
|
||||
"""
|
||||
# Copy python source code into BUILD_DIR
|
||||
if os.path.exists(BUILD_DIR):
|
||||
shutil.rmtree(BUILD_DIR)
|
||||
python_files = glob.glob('python/**/*.py', recursive=True)
|
||||
python_files.append('__init__.py')
|
||||
for python_file in python_files:
|
||||
# Exclude test files from pip package
|
||||
if '_test.py' in python_file:
|
||||
continue
|
||||
build_target_file = os.path.join(BUILD_MM_DIR, python_file)
|
||||
with open(python_file, 'r') as file:
|
||||
filedata = file.read()
|
||||
# Rename all mediapipe.model_maker imports to mediapipe_model_maker
|
||||
filedata = filedata.replace('from mediapipe.model_maker',
|
||||
'from mediapipe_model_maker')
|
||||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
||||
with open(build_target_file, 'w') as file:
|
||||
file.write(filedata)
|
||||
|
||||
# Use bazel to download GCS model files
|
||||
model_build_files = ['models/gesture_recognizer/BUILD']
|
||||
for model_build_file in model_build_files:
|
||||
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
|
||||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
||||
shutil.copy(model_build_file, build_target_file)
|
||||
external_files = [
|
||||
'models/gesture_recognizer/canned_gesture_classifier.tflite',
|
||||
'models/gesture_recognizer/gesture_embedder.tflite',
|
||||
'models/gesture_recognizer/hand_landmark_full.tflite',
|
||||
'models/gesture_recognizer/palm_detection_full.tflite',
|
||||
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
|
||||
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
|
||||
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
|
||||
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
|
||||
]
|
||||
for elem in external_files:
|
||||
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
|
||||
sys.stderr.write('downloading file: %s\n' % external_file)
|
||||
fetch_model_command = [
|
||||
'bazel',
|
||||
'build',
|
||||
external_file,
|
||||
]
|
||||
if subprocess.call(fetch_model_command) != 0:
|
||||
sys.exit(-1)
|
||||
_copy_to_pip_src_dir(external_file)
|
||||
|
||||
_setup_build_dir()
|
||||
|
||||
setuptools.setup(
|
||||
name='mediapipe-model-maker',
|
||||
version=__version__,
|
||||
url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker',
|
||||
description='MediaPipe Model Maker is a simple, low-code solution for customizing on-device ML models',
|
||||
author='The MediaPipe Authors',
|
||||
author_email='mediapipe@google.com',
|
||||
long_description='',
|
||||
long_description_content_type='text/markdown',
|
||||
packages=setuptools.find_packages(where=SRC_NAME),
|
||||
package_dir={'': SRC_NAME},
|
||||
install_requires=_parse_requirements('requirements.txt'),
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: MacOS :: MacOS X',
|
||||
'Operating System :: Microsoft :: Windows',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3 :: Only',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
license='Apache 2.0',
|
||||
keywords=['mediapipe', 'model', 'maker'],
|
||||
)
|
|
@ -24,7 +24,6 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
namespace {
|
||||
|
||||
// NORM_LANDMARKS is either the full set of landmarks for the hand, or
|
||||
|
|
|
@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT";
|
|||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kTrackingRectTag[] = "TRACKING_RECT";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
// TODO: Use rect rotation.
|
||||
// Verifies that Intersection over Union of previous frame rect and current
|
||||
// frame re-crop rect is less than threshold.
|
||||
|
|
|
@ -275,7 +275,6 @@ cc_library(
|
|||
":tflite_tensors_to_objects_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -299,7 +298,6 @@ cc_library(
|
|||
":tensors_to_objects_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -316,13 +314,11 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":annotation_cc_proto",
|
||||
":belief_decoder_config_cc_proto",
|
||||
":decoder",
|
||||
":lift_2d_frame_annotation_to_3d_calculator_cc_proto",
|
||||
":tensor_util",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
|
|
@ -34,6 +34,8 @@ namespace {
|
|||
constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION";
|
||||
constexpr char kOutputNormRectsTag[] = "NORM_RECTS";
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
||||
} // namespace
|
||||
|
||||
// A calculator that converts FrameAnnotation proto to NormalizedRect.
|
||||
|
|
|
@ -68,7 +68,6 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = ["Accelerate"],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
@ -90,6 +89,7 @@ objc_library(
|
|||
"//mediapipe/gpu:metal_shared_resources",
|
||||
"//mediapipe/gpu:pixel_buffer_pool_util",
|
||||
"//mediapipe/util:cpu_util",
|
||||
"//third_party/apple_frameworks:Accelerate",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -120,13 +120,13 @@ objc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
sdk_frameworks = [
|
||||
"AVFoundation",
|
||||
"CoreVideo",
|
||||
"Foundation",
|
||||
],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//third_party/apple_frameworks:AVFoundation",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Foundation",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
|
@ -140,16 +140,14 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"Foundation",
|
||||
"GLKit",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
":mediapipe_framework_ios",
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_quad_renderer",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//third_party/apple_frameworks:Foundation",
|
||||
"//third_party/apple_frameworks:GLKit",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -164,16 +162,14 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"Foundation",
|
||||
"GLKit",
|
||||
],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":mediapipe_framework_ios",
|
||||
":mediapipe_gl_view_renderer",
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//third_party/apple_frameworks:Foundation",
|
||||
"//third_party/apple_frameworks:GLKit",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -188,13 +184,11 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Foundation",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
":mediapipe_framework_ios",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:Foundation",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -211,23 +205,21 @@ objc_library(
|
|||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"AVFoundation",
|
||||
"Accelerate",
|
||||
"CoreGraphics",
|
||||
"CoreMedia",
|
||||
"CoreVideo",
|
||||
"GLKit",
|
||||
"OpenGLES",
|
||||
"QuartzCore",
|
||||
"UIKit",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
":CGImageRefUtils",
|
||||
":Weakify",
|
||||
":mediapipe_framework_ios",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//third_party/apple_frameworks:AVFoundation",
|
||||
"//third_party/apple_frameworks:Accelerate",
|
||||
"//third_party/apple_frameworks:CoreGraphics",
|
||||
"//third_party/apple_frameworks:CoreMedia",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:GLKit",
|
||||
"//third_party/apple_frameworks:OpenGLES",
|
||||
"//third_party/apple_frameworks:QuartzCore",
|
||||
"//third_party/apple_frameworks:UIKit",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -245,16 +237,6 @@ objc_library(
|
|||
data = [
|
||||
"testdata/googlelogo_color_272x92dp.png",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"AVFoundation",
|
||||
"Accelerate",
|
||||
"CoreGraphics",
|
||||
"CoreMedia",
|
||||
"CoreVideo",
|
||||
"GLKit",
|
||||
"QuartzCore",
|
||||
"UIKit",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
":CGImageRefUtils",
|
||||
|
@ -263,6 +245,14 @@ objc_library(
|
|||
":mediapipe_framework_ios",
|
||||
":mediapipe_input_sources_ios",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//third_party/apple_frameworks:AVFoundation",
|
||||
"//third_party/apple_frameworks:Accelerate",
|
||||
"//third_party/apple_frameworks:CoreGraphics",
|
||||
"//third_party/apple_frameworks:CoreMedia",
|
||||
"//third_party/apple_frameworks:CoreVideo",
|
||||
"//third_party/apple_frameworks:GLKit",
|
||||
"//third_party/apple_frameworks:QuartzCore",
|
||||
"//third_party/apple_frameworks:UIKit",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -22,9 +22,7 @@ cc_library(
|
|||
name = "audio_classifier",
|
||||
srcs = ["audio_classifier.cc"],
|
||||
hdrs = ["audio_classifier.h"],
|
||||
visibility = [
|
||||
"//mediapipe/tasks:users",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":audio_classifier_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
|
|
|
@ -18,6 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.audio.audio_classifier.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
|
|
|
@ -22,9 +22,7 @@ cc_library(
|
|||
name = "audio_embedder",
|
||||
srcs = ["audio_embedder.cc"],
|
||||
hdrs = ["audio_embedder.h"],
|
||||
visibility = [
|
||||
"//mediapipe/tasks:users",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":audio_embedder_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
|
|
|
@ -18,6 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.audio.audio_embedder.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ licenses(["notice"])
|
|||
|
||||
cc_library(
|
||||
name = "rect",
|
||||
srcs = ["rect.cc"],
|
||||
hdrs = ["rect.h"],
|
||||
)
|
||||
|
||||
|
@ -41,6 +42,18 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detection_result",
|
||||
srcs = ["detection_result.cc"],
|
||||
hdrs = ["detection_result.h"],
|
||||
deps = [
|
||||
":category",
|
||||
":rect",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "embedding_result",
|
||||
srcs = ["embedding_result.cc"],
|
||||
|
|
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
|
@ -0,0 +1,73 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||
|
||||
#include <strings.h>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
constexpr int kDefaultCategoryIndex = -1;
|
||||
|
||||
Detection ConvertToDetectionResult(
|
||||
const mediapipe::Detection& detection_proto) {
|
||||
Detection detection;
|
||||
for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
|
||||
detection.categories.push_back(
|
||||
{/* index= */ detection_proto.label_id_size() > idx
|
||||
? detection_proto.label_id(idx)
|
||||
: kDefaultCategoryIndex,
|
||||
/* score= */ detection_proto.score(idx),
|
||||
/* category_name */ detection_proto.label_size() > idx
|
||||
? detection_proto.label(idx)
|
||||
: "",
|
||||
/* display_name */ detection_proto.display_name_size() > idx
|
||||
? detection_proto.display_name(idx)
|
||||
: ""});
|
||||
}
|
||||
Rect bounding_box;
|
||||
if (detection_proto.location_data().has_bounding_box()) {
|
||||
mediapipe::LocationData::BoundingBox bounding_box_proto =
|
||||
detection_proto.location_data().bounding_box();
|
||||
bounding_box.left = bounding_box_proto.xmin();
|
||||
bounding_box.top = bounding_box_proto.ymin();
|
||||
bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
|
||||
bounding_box.bottom =
|
||||
bounding_box_proto.ymin() + bounding_box_proto.height();
|
||||
}
|
||||
detection.bounding_box = bounding_box;
|
||||
return detection;
|
||||
}
|
||||
|
||||
DetectionResult ConvertToDetectionResult(
|
||||
std::vector<mediapipe::Detection> detections_proto) {
|
||||
DetectionResult detection_result;
|
||||
detection_result.detections.reserve(detections_proto.size());
|
||||
for (const auto& detection_proto : detections_proto) {
|
||||
detection_result.detections.push_back(
|
||||
ConvertToDetectionResult(detection_proto));
|
||||
}
|
||||
return detection_result;
|
||||
}
|
||||
} // namespace mediapipe::tasks::components::containers
|
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
|
@ -0,0 +1,52 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Detection for a single bounding box.
|
||||
struct Detection {
|
||||
// A vector of detected categories.
|
||||
std::vector<Category> categories;
|
||||
// The bounding box location.
|
||||
Rect bounding_box;
|
||||
};
|
||||
|
||||
// Detection results of a model.
|
||||
struct DetectionResult {
|
||||
// A vector of Detections.
|
||||
std::vector<Detection> detections;
|
||||
};
|
||||
|
||||
// Utility function to convert from Detection proto to Detection struct.
|
||||
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
|
||||
|
||||
// Utility function to convert from list of Detection proto to DetectionResult
|
||||
// struct.
|
||||
DetectionResult ConvertToDetectionResult(
|
||||
std::vector<mediapipe::Detection> detections_proto);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
|
||||
return RectF{static_cast<float>(rect.left) / image_width,
|
||||
static_cast<float>(rect.top) / image_height,
|
||||
static_cast<float>(rect.right) / image_width,
|
||||
static_cast<float>(rect.bottom) / image_height};
|
||||
}
|
||||
|
||||
Rect ToRect(const RectF& rect, int image_height, int image_width) {
|
||||
return Rect{static_cast<int>(rect.left * image_width),
|
||||
static_cast<int>(rect.top * image_height),
|
||||
static_cast<int>(rect.right * image_width),
|
||||
static_cast<int>(rect.bottom * image_height)};
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user