Merge branch 'master' into ios-task

This commit is contained in:
Prianka Liz Kariat 2022-12-14 11:19:31 +05:30
commit ee230520da
273 changed files with 6982 additions and 1413 deletions

View File

@ -0,0 +1,25 @@
---
name: "Tasks Issue"
about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms.
labels: type:support
---
<em>Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue.<em>
**System information** (Please provide as much relevant information as possible)
- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
- MediaPipe Tasks SDK version:
- Task name (e.g. Object detection, Gesture recognition etc.):
- Programming Language and version ( e.g. C++, Python, Java):
**Describe the expected behavior:**
**Standalone code you may have used to try to get what you need :**
If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
**Other info / Complete Logs :**
Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached:

View File

@ -0,0 +1,25 @@
---
name: "Model Maker Issue"
about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions.
labels: type:support
---
<em>Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue.<em>
**System information** (Please provide as much relevant information as possible)
- Have I written custom code (as opposed to using a stock example script provided in MediaPipe):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Python version (e.g. 3.8):
- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/):
- Task name (e.g. Image classification, Gesture recognition etc.):
**Describe the expected behavior:**
**Standalone code you may have used to try to get what you need :**
If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem:
**Other info / Complete Logs :**
Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached:

View File

@ -1,6 +1,6 @@
---
name: "Solution Issue"
about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
name: "Solution (legacy) Issue"
about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
labels: type:support
---

View File

@ -0,0 +1,19 @@
---
name: "Studio Issue"
about: Use this template for assistance with the MediaPipe Studio application.
labels: type:support
---
<em>Please make sure that this is a MediaPipe Studio issue.<em>
**System information** (Please provide as much relevant information as possible)
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4):
- Browser and Version
- Any microphone or camera hardware
- URL that shows the problem
**Describe the expected behavior:**
**Other info / Complete Logs :**
Include any js console logs that would be helpful to diagnose the problem.
Large logs and files should be attached:

View File

@ -320,12 +320,30 @@ http_archive(
],
)
# iOS basic build deps.
# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee
# that the target @zlib//:mini_zlib is available
http_archive(
name = "zlib",
build_file = "//third_party:zlib.BUILD",
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
strip_prefix = "zlib-1.2.11",
urls = [
"http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
"http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
],
patches = [
"@//third_party:zlib.diff",
],
patch_args = [
"-p1",
],
)
# iOS basic build deps.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2",
url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz",
sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911",
url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz",
patches = [
# Bypass checking ios unit test runner when building MP ios applications.
"@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff"
@ -339,29 +357,24 @@ load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_rules_swift//swift:repositories.bzl",
"swift_rules_dependencies",
)
swift_rules_dependencies()
http_archive(
name = "build_bazel_apple_support",
sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35",
urls = [
"https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz"
],
load(
"@build_bazel_rules_swift//swift:extras.bzl",
"swift_rules_extra_dependencies",
)
swift_rules_extra_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
# More iOS deps.
@ -442,25 +455,6 @@ http_archive(
],
)
# Load Zlib before initializing TensorFlow to guarantee that the target
# @zlib//:mini_zlib is available
http_archive(
name = "zlib",
build_file = "//third_party:zlib.BUILD",
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
strip_prefix = "zlib-1.2.11",
urls = [
"http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
"http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
],
patches = [
"@//third_party:zlib.diff",
],
patch_args = [
"-p1",
],
)
# TensorFlow repo should always go after the other external dependencies.
# TF on 2022-08-10.
_TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b"

View File

@ -0,0 +1,81 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe Model Maker reference docs generation script.
This script generates API reference docs for the `mediapipe` PIP package.
$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker
$> python build_model_maker_api_docs.py
"""
import os
from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
import mediapipe_model_maker # pytype: disable=import-error
except ImportError as e:
raise ImportError('Please `pip install mediapipe-model-maker`.') from e
PROJECT_SHORT_NAME = 'mediapipe_model_maker'
PROJECT_FULL_NAME = 'MediaPipe Model Maker'
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
default='/tmp/generated_docs',
help='Where to write the resulting docs.')
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
'Path prefix in the _toc.yaml')
def gen_api_docs():
"""Generates API docs for the mediapipe-model-maker package."""
doc_generator = generate_lib.DocGenerator(
root_title=PROJECT_FULL_NAME,
py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)],
base_dir=os.path.dirname(mediapipe_model_maker.__file__),
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
callbacks=[],
)
doc_generator.build(_OUTPUT_DIR.value)
print('Docs output to:', _OUTPUT_DIR.value)
def main(_):
gen_api_docs()
if __name__ == '__main__':
app.run(main)

View File

@ -26,7 +26,6 @@ from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
@ -45,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string(
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe',
'https://github.com/google/mediapipe/blob/master/mediapipe',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python',
'Path prefix in the _toc.yaml')
@ -68,10 +67,7 @@ def gen_api_docs():
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
# This callback ensures that docs are only generated for objects that
# are explicitly imported in your __init__.py files. There are other
# options but this is a good starting point.
callbacks=[public_api.explicit_package_contents_filter],
callbacks=[],
)
doc_generator.build(_OUTPUT_DIR.value)

View File

@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])
mediapipe_proto_library(
name = "concatenate_vector_calculator_proto",
srcs = ["concatenate_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -32,7 +31,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "dequantize_byte_array_calculator_proto",
srcs = ["dequantize_byte_array_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -42,7 +40,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "packet_cloner_calculator_proto",
srcs = ["packet_cloner_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -52,7 +49,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "packet_resampler_calculator_proto",
srcs = ["packet_resampler_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -62,7 +58,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "packet_thinner_calculator_proto",
srcs = ["packet_thinner_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -72,7 +67,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "split_vector_calculator_proto",
srcs = ["split_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -82,7 +76,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "quantize_float_vector_calculator_proto",
srcs = ["quantize_float_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -92,7 +85,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "sequence_shift_calculator_proto",
srcs = ["sequence_shift_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -102,7 +94,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "gate_calculator_proto",
srcs = ["gate_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -112,7 +103,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "constant_side_packet_calculator_proto",
srcs = ["constant_side_packet_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -124,7 +114,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "clip_vector_size_calculator_proto",
srcs = ["clip_vector_size_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -134,7 +123,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "flow_limiter_calculator_proto",
srcs = ["flow_limiter_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -144,7 +132,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "graph_profile_calculator_proto",
srcs = ["graph_profile_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -154,7 +141,6 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "get_vector_item_calculator_proto",
srcs = ["get_vector_item_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -164,7 +150,6 @@ mediapipe_proto_library(
cc_library(
name = "add_header_calculator",
srcs = ["add_header_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -193,7 +178,6 @@ cc_library(
name = "begin_loop_calculator",
srcs = ["begin_loop_calculator.cc"],
hdrs = ["begin_loop_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
@ -216,7 +200,6 @@ cc_library(
name = "end_loop_calculator",
srcs = ["end_loop_calculator.cc"],
hdrs = ["end_loop_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
@ -258,7 +241,6 @@ cc_test(
cc_library(
name = "concatenate_vector_calculator_hdr",
hdrs = ["concatenate_vector_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -284,7 +266,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework/api2:node",
@ -311,7 +292,6 @@ cc_library(
cc_library(
name = "concatenate_detection_vector_calculator",
srcs = ["concatenate_detection_vector_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator",
"//mediapipe/framework:calculator_framework",
@ -323,7 +303,6 @@ cc_library(
cc_library(
name = "concatenate_proto_list_calculator",
srcs = ["concatenate_proto_list_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -372,7 +351,6 @@ cc_library(
name = "clip_vector_size_calculator",
srcs = ["clip_vector_size_calculator.cc"],
hdrs = ["clip_vector_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":clip_vector_size_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -388,7 +366,6 @@ cc_library(
cc_library(
name = "clip_detection_vector_size_calculator",
srcs = ["clip_detection_vector_size_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/framework:calculator_framework",
@ -415,9 +392,6 @@ cc_test(
cc_library(
name = "counting_source_calculator",
srcs = ["counting_source_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
@ -430,9 +404,6 @@ cc_library(
cc_library(
name = "make_pair_calculator",
srcs = ["make_pair_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -461,9 +432,6 @@ cc_test(
cc_library(
name = "matrix_multiply_calculator",
srcs = ["matrix_multiply_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -477,9 +445,6 @@ cc_library(
cc_library(
name = "matrix_subtract_calculator",
srcs = ["matrix_subtract_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -493,9 +458,6 @@ cc_library(
cc_library(
name = "mux_calculator",
srcs = ["mux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -508,9 +470,6 @@ cc_library(
cc_library(
name = "non_zero_calculator",
srcs = ["non_zero_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -556,9 +515,6 @@ cc_test(
cc_library(
name = "packet_cloner_calculator",
srcs = ["packet_cloner_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":packet_cloner_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -587,7 +543,6 @@ cc_test(
cc_library(
name = "packet_inner_join_calculator",
srcs = ["packet_inner_join_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
@ -611,7 +566,6 @@ cc_test(
cc_library(
name = "packet_thinner_calculator",
srcs = ["packet_thinner_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
@ -643,9 +597,6 @@ cc_test(
cc_library(
name = "pass_through_calculator",
srcs = ["pass_through_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
@ -656,9 +607,6 @@ cc_library(
cc_library(
name = "round_robin_demux_calculator",
srcs = ["round_robin_demux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -670,9 +618,6 @@ cc_library(
cc_library(
name = "immediate_mux_calculator",
srcs = ["immediate_mux_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
@ -684,7 +629,6 @@ cc_library(
cc_library(
name = "packet_presence_calculator",
srcs = ["packet_presence_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
@ -713,7 +657,6 @@ cc_test(
cc_library(
name = "previous_loopback_calculator",
srcs = ["previous_loopback_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
@ -729,7 +672,6 @@ cc_library(
cc_library(
name = "flow_limiter_calculator",
srcs = ["flow_limiter_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":flow_limiter_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -746,7 +688,6 @@ cc_library(
cc_library(
name = "string_to_int_calculator",
srcs = ["string_to_int_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:integral_types",
@ -759,7 +700,6 @@ cc_library(
cc_library(
name = "default_side_packet_calculator",
srcs = ["default_side_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
@ -771,7 +711,6 @@ cc_library(
cc_library(
name = "side_packet_to_stream_calculator",
srcs = ["side_packet_to_stream_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
@ -822,9 +761,6 @@ cc_library(
name = "packet_resampler_calculator",
srcs = ["packet_resampler_calculator.cc"],
hdrs = ["packet_resampler_calculator.h"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -884,7 +820,6 @@ cc_test(
cc_test(
name = "matrix_multiply_calculator_test",
srcs = ["matrix_multiply_calculator_test.cc"],
visibility = ["//visibility:private"],
deps = [
":matrix_multiply_calculator",
"//mediapipe/framework:calculator_framework",
@ -900,7 +835,6 @@ cc_test(
cc_test(
name = "matrix_subtract_calculator_test",
srcs = ["matrix_subtract_calculator_test.cc"],
visibility = ["//visibility:private"],
deps = [
":matrix_subtract_calculator",
"//mediapipe/framework:calculator_framework",
@ -950,7 +884,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":split_vector_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
@ -996,7 +929,6 @@ cc_test(
cc_library(
name = "split_proto_list_calculator",
srcs = ["split_proto_list_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":split_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1028,7 +960,6 @@ cc_test(
cc_library(
name = "dequantize_byte_array_calculator",
srcs = ["dequantize_byte_array_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":dequantize_byte_array_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
@ -1054,7 +985,6 @@ cc_test(
cc_library(
name = "quantize_float_vector_calculator",
srcs = ["quantize_float_vector_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":quantize_float_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
@ -1080,7 +1010,6 @@ cc_test(
cc_library(
name = "sequence_shift_calculator",
srcs = ["sequence_shift_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":sequence_shift_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1105,7 +1034,6 @@ cc_test(
cc_library(
name = "gate_calculator",
srcs = ["gate_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":gate_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1131,7 +1059,6 @@ cc_test(
cc_library(
name = "matrix_to_vector_calculator",
srcs = ["matrix_to_vector_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -1167,7 +1094,6 @@ cc_test(
cc_library(
name = "merge_calculator",
srcs = ["merge_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -1193,7 +1119,6 @@ cc_test(
cc_library(
name = "stream_to_side_packet_calculator",
srcs = ["stream_to_side_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
@ -1219,7 +1144,6 @@ cc_test(
cc_library(
name = "constant_side_packet_calculator",
srcs = ["constant_side_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":constant_side_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1249,7 +1173,6 @@ cc_test(
cc_library(
name = "graph_profile_calculator",
srcs = ["graph_profile_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":graph_profile_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1291,7 +1214,6 @@ cc_library(
name = "get_vector_item_calculator",
srcs = ["get_vector_item_calculator.cc"],
hdrs = ["get_vector_item_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":get_vector_item_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -1325,7 +1247,6 @@ cc_library(
name = "vector_indices_calculator",
srcs = ["vector_indices_calculator.cc"],
hdrs = ["vector_indices_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -1351,7 +1272,6 @@ cc_library(
name = "vector_size_calculator",
srcs = ["vector_size_calculator.cc"],
hdrs = ["vector_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
@ -1365,9 +1285,6 @@ cc_library(
cc_library(
name = "packet_sequencer_calculator",
srcs = ["packet_sequencer_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:contract",
@ -1402,11 +1319,11 @@ cc_library(
name = "merge_to_vector_calculator",
srcs = ["merge_to_vector_calculator.cc"],
hdrs = ["merge_to_vector_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"@com_google_absl//absl/status",
],
@ -1416,7 +1333,6 @@ cc_library(
mediapipe_proto_library(
name = "bypass_calculator_proto",
srcs = ["bypass_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -1426,7 +1342,6 @@ mediapipe_proto_library(
cc_library(
name = "bypass_calculator",
srcs = ["bypass_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":bypass_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",

View File

@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node {
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
absl::Status Open(CalculatorContext* cc) final {
cc->SetOffset(mediapipe::TimestampDiff(0));
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
return absl::OkStatus();
@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node {
return absl::OkStatus();
}
RET_CHECK(idx >= 0 && idx < items.size());
kOut(cc).Send(items[idx]);
RET_CHECK(idx >= 0);
RET_CHECK(options.output_empty_on_oob() || idx < items.size());
if (idx < items.size()) {
kOut(cc).Send(items[idx]);
}
return absl::OkStatus();
}

View File

@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions {
// Index of vector item to get. INDEX input stream can be used instead, or to
// override.
optional int32 item_index = 1;
// Set to true to output an empty packet when the index is out of bounds.
optional bool output_empty_on_oob = 2;
}

View File

@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() {
)");
}
CalculatorRunner MakeRunnerWithOptions(int set_index) {
return CalculatorRunner(absl::StrFormat(R"(
CalculatorRunner MakeRunnerWithOptions(int set_index,
bool output_empty_on_oob = false) {
return CalculatorRunner(
absl::StrFormat(R"(
calculator: "TestGetIntVectorItemCalculator"
input_stream: "VECTOR:vector_stream"
output_stream: "ITEM:item_stream"
options {
[mediapipe.GetVectorItemCalculatorOptions.ext] {
item_index: %d
output_empty_on_oob: %s
}
}
)",
set_index));
set_index, output_empty_on_oob ? "true" : "false"));
}
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
}
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) {
const int try_index = 3;
CalculatorRunner runner = MakeRunnerWithOptions(try_index, true);
const std::vector<int> inputs = {1, 2, 3};
AddInputVector(runner, inputs, 1);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
EXPECT_THAT(outputs, testing::ElementsAre());
}
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/calculators/core/merge_to_vector_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
namespace mediapipe {
@ -23,5 +24,13 @@ namespace api2 {
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
typedef MergeToVectorCalculator<mediapipe::GpuBuffer>
MergeGpuBuffersToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
typedef MergeToVectorCalculator<mediapipe::Detection>
MergeDetectionsToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
return absl::OkStatus();
}
absl::Status Open(::mediapipe::CalculatorContext* cc) {
cc->SetOffset(::mediapipe::TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) {
const int input_num = kIn(cc).Count();
std::vector<T> output_vector(input_num);
std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(),
[](const auto& elem) -> T { return elem.Get(); });
std::vector<T> output_vector;
for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
const auto& elem = *it;
if (!elem.IsEmpty()) {
output_vector.push_back(elem.Get());
}
}
kOut(cc).Send(output_vector);
return absl::OkStatus();
}

View File

@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
namespace mediapipe {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::Rect;
#if !MEDIAPIPE_DISABLE_GPU
#endif // !MEDIAPIPE_DISABLE_GPU

View File

@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) {
auto cc = absl::make_unique<CalculatorContext>(
calculator_state.get(), inputTags, tool::CreateTagMap({}).value());
auto& inputs = cc->Inputs();
mediapipe::Rect rect = ParseTextProtoOrDie<mediapipe::Rect>(
Rect rect = ParseTextProtoOrDie<Rect>(
R"pb(
width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5
)pb");
inputs.Tag(kRectTag).Value() = MakePacket<mediapipe::Rect>(rect);
inputs.Tag(kRectTag).Value() = MakePacket<Rect>(rect);
RectSpec expectRect = {
.width = 1,
.height = 1,

View File

@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"])
proto_library(
name = "callback_packet_calculator_proto",
srcs = ["callback_packet_calculator.proto"],
visibility = ["//visibility:public"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = ["//mediapipe/framework:calculator_proto"],
)
@ -29,14 +29,14 @@ mediapipe_cc_proto_library(
name = "callback_packet_calculator_cc_proto",
srcs = ["callback_packet_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [":callback_packet_calculator_proto"],
)
cc_library(
name = "callback_packet_calculator",
srcs = ["callback_packet_calculator.cc"],
visibility = ["//visibility:public"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
":callback_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_base",

View File

@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])
exports_files(
glob(["testdata/image_to_tensor/*"]),
@ -44,9 +44,6 @@ selects.config_setting_group(
mediapipe_proto_library(
name = "audio_to_tensor_calculator_proto",
srcs = ["audio_to_tensor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -64,9 +61,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":audio_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -113,9 +107,6 @@ cc_test(
mediapipe_proto_library(
name = "tensors_to_audio_calculator_proto",
srcs = ["tensors_to_audio_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -125,9 +116,6 @@ mediapipe_proto_library(
cc_library(
name = "tensors_to_audio_calculator",
srcs = ["tensors_to_audio_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -164,9 +152,6 @@ cc_test(
mediapipe_proto_library(
name = "feedback_tensors_calculator_proto",
srcs = ["feedback_tensors_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -184,9 +169,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":feedback_tensors_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -216,9 +198,6 @@ cc_test(
mediapipe_proto_library(
name = "bert_preprocessor_calculator_proto",
srcs = ["bert_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -228,9 +207,6 @@ mediapipe_proto_library(
cc_library(
name = "bert_preprocessor_calculator",
srcs = ["bert_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":bert_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -274,9 +250,6 @@ cc_test(
mediapipe_proto_library(
name = "regex_preprocessor_calculator_proto",
srcs = ["regex_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -286,9 +259,6 @@ mediapipe_proto_library(
cc_library(
name = "regex_preprocessor_calculator",
srcs = ["regex_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":regex_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -330,9 +300,6 @@ cc_test(
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
@ -405,7 +372,6 @@ cc_test(
mediapipe_proto_library(
name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -432,7 +398,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":inference_calculator_cc_proto",
":inference_calculator_options_lib",
@ -457,7 +422,6 @@ cc_library(
name = "inference_calculator_gl",
srcs = ["inference_calculator_gl.cc"],
tags = ["nomac"], # config problem with cpuinfo via TF
visibility = ["//visibility:public"],
deps = [
":inference_calculator_cc_proto",
":inference_calculator_interface",
@ -475,7 +439,6 @@ cc_library(
name = "inference_calculator_gl_advanced",
srcs = ["inference_calculator_gl_advanced.cc"],
tags = ["nomac"],
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
"@com_google_absl//absl/memory",
@ -506,7 +469,6 @@ cc_library(
"-framework MetalKit",
],
tags = ["ios"],
visibility = ["//visibility:public"],
deps = [
"inference_calculator_interface",
"//mediapipe/gpu:MPPMetalHelper",
@ -535,7 +497,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework/formats:tensor",
@ -555,7 +516,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":inference_runner",
"//mediapipe/framework:mediapipe_profiling",
@ -585,7 +545,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
":inference_calculator_utils",
@ -632,7 +591,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
":inference_calculator_utils",
@ -648,7 +606,6 @@ cc_library(
cc_library(
name = "inference_calculator_gl_if_compute_shader_available",
visibility = ["//visibility:public"],
deps = selects.with_or({
":compute_shader_unavailable": [],
"//conditions:default": [
@ -664,7 +621,6 @@ cc_library(
# inference_calculator_interface.
cc_library(
name = "inference_calculator",
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
":inference_calculator_cpu",
@ -678,7 +634,6 @@ cc_library(
mediapipe_proto_library(
name = "tensor_converter_calculator_proto",
srcs = ["tensor_converter_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -703,7 +658,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensor_converter_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -722,6 +676,7 @@ cc_library(
cc_library(
name = "tensor_converter_calculator_gpu_deps",
visibility = ["//visibility:private"],
deps = select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
@ -766,7 +721,6 @@ cc_test(
mediapipe_proto_library(
name = "tensors_to_detections_calculator_proto",
srcs = ["tensors_to_detections_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -791,7 +745,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensors_to_detections_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
@ -814,6 +767,7 @@ cc_library(
cc_library(
name = "tensors_to_detections_calculator_gpu_deps",
visibility = ["//visibility:private"],
deps = select({
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil",
@ -829,7 +783,6 @@ cc_library(
mediapipe_proto_library(
name = "tensors_to_landmarks_calculator_proto",
srcs = ["tensors_to_landmarks_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -846,7 +799,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -861,7 +813,6 @@ cc_library(
mediapipe_proto_library(
name = "landmarks_to_tensor_calculator_proto",
srcs = ["landmarks_to_tensor_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -879,7 +830,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":landmarks_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -912,7 +862,6 @@ cc_test(
mediapipe_proto_library(
name = "tensors_to_floats_calculator_proto",
srcs = ["tensors_to_floats_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -929,7 +878,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensors_to_floats_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
@ -967,7 +915,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensors_to_classification_calculator_cc_proto",
"@com_google_absl//absl/container:node_hash_map",
@ -998,7 +945,6 @@ cc_library(
mediapipe_proto_library(
name = "tensors_to_classification_calculator_proto",
srcs = ["tensors_to_classification_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -1036,7 +982,6 @@ cc_library(
"//conditions:default": [],
}),
features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_calculator_cc_proto",
":image_to_tensor_converter",
@ -1065,6 +1010,7 @@ cc_library(
cc_library(
name = "image_to_tensor_calculator_gpu_deps",
visibility = ["//visibility:private"],
deps = selects.with_or({
"//mediapipe:android": [
":image_to_tensor_converter_gl_buffer",
@ -1088,7 +1034,6 @@ cc_library(
mediapipe_proto_library(
name = "image_to_tensor_calculator_proto",
srcs = ["image_to_tensor_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -1151,7 +1096,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_utils",
"//mediapipe/framework/formats:image",
@ -1171,7 +1115,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_converter",
":image_to_tensor_utils",
@ -1191,6 +1134,7 @@ cc_library(
name = "image_to_tensor_converter_gl_buffer",
srcs = ["image_to_tensor_converter_gl_buffer.cc"],
hdrs = ["image_to_tensor_converter_gl_buffer.h"],
visibility = ["//visibility:private"],
deps = ["//mediapipe/framework:port"] + selects.with_or({
"//mediapipe:apple": [],
"//conditions:default": [
@ -1224,6 +1168,7 @@ cc_library(
name = "image_to_tensor_converter_gl_texture",
srcs = ["image_to_tensor_converter_gl_texture.cc"],
hdrs = ["image_to_tensor_converter_gl_texture.h"],
visibility = ["//visibility:private"],
deps = ["//mediapipe/framework:port"] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
@ -1248,6 +1193,7 @@ cc_library(
name = "image_to_tensor_converter_gl_utils",
srcs = ["image_to_tensor_converter_gl_utils.cc"],
hdrs = ["image_to_tensor_converter_gl_utils.h"],
visibility = ["//visibility:private"],
deps = ["//mediapipe/framework:port"] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
@ -1277,6 +1223,7 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:private"],
deps = ["//mediapipe/framework:port"] + select({
"//mediapipe:apple": [
":image_to_tensor_converter",
@ -1308,7 +1255,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_calculator_cc_proto",
"@com_google_absl//absl/status",
@ -1351,7 +1297,6 @@ selects.config_setting_group(
mediapipe_proto_library(
name = "tensors_to_segmentation_calculator_proto",
srcs = ["tensors_to_segmentation_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -1369,7 +1314,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tensors_to_segmentation_calculator_cc_proto",
"@com_google_absl//absl/strings:str_format",
@ -1427,7 +1371,6 @@ cc_library(
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",

View File

@ -17,6 +17,7 @@ syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "InferenceCalculatorProto";

View File

@ -456,6 +456,23 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "detections_deduplicate_calculator",
srcs = [
"detections_deduplicate_calculator.cc",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
cc_library(
name = "rect_transformation_calculator",
srcs = ["rect_transformation_calculator.cc"],

View File

@ -0,0 +1,114 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
struct BoundingBoxHash {
size_t operator()(const LocationData::BoundingBox& bbox) const {
return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
}
};
struct BoundingBoxEq {
bool operator()(const LocationData::BoundingBox& lhs,
const LocationData::BoundingBox& rhs) const {
return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
lhs.width() == rhs.width() && lhs.height() == rhs.height();
}
};
} // namespace
// This Calculator deduplicates the bunding boxes with exactly the same
// coordinates, and folds the labels into a single Detection proto. Note
// non-maximum-suppression remove the overlapping bounding boxes within a class,
// while the deduplication operation merges bounding boxes from different
// classes.
// Example config:
// node {
// calculator: "DetectionsDeduplicateCalculator"
// input_stream: "detections"
// output_stream: "deduplicated_detections"
// }
class DetectionsDeduplicateCalculator : public Node {
public:
static constexpr Input<std::vector<Detection>> kIn{""};
static constexpr Output<std::vector<Detection>> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Open(mediapipe::CalculatorContext* cc) {
cc->SetOffset(::mediapipe::TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(mediapipe::CalculatorContext* cc) {
const std::vector<Detection>& raw_detections = kIn(cc).Get();
absl::flat_hash_map<LocationData::BoundingBox, Detection*, BoundingBoxHash,
BoundingBoxEq>
bbox_to_detections;
std::vector<Detection> deduplicated_detections;
for (const auto& detection : raw_detections) {
if (!detection.has_location_data() ||
!detection.location_data().has_bounding_box()) {
return absl::InvalidArgumentError(
"The location data of Detections must be BoundingBox.");
}
if (bbox_to_detections.contains(
detection.location_data().bounding_box())) {
// The bbox location already exists. Merge the detection labels into
// the existing detection proto.
Detection& deduplicated_detection =
*bbox_to_detections[detection.location_data().bounding_box()];
deduplicated_detection.mutable_score()->MergeFrom(detection.score());
deduplicated_detection.mutable_label()->MergeFrom(detection.label());
deduplicated_detection.mutable_label_id()->MergeFrom(
detection.label_id());
deduplicated_detection.mutable_display_name()->MergeFrom(
detection.display_name());
} else {
// The bbox location appears first time. Add the detection to output
// detection vector.
deduplicated_detections.push_back(detection);
bbox_to_detections[detection.location_data().bounding_box()] =
&deduplicated_detections.back();
}
}
kOut(cc).Send(std::move(deduplicated_detections));
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kNormRectsTag[] = "NORM_RECTS";
using ::mediapipe::NormalizedRect;
using ::mediapipe::Rect;
constexpr float kMinFloat = std::numeric_limits<float>::lowest();
constexpr float kMaxFloat = std::numeric_limits<float>::max();

View File

@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kRectTag[] = "RECT";
constexpr char kDetectionTag[] = "DETECTION";
using ::mediapipe::NormalizedRect;
using ::mediapipe::Rect;
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
testing::Value(arg.y_center(), testing::Eq(y_center)) &&

View File

@ -24,6 +24,8 @@
namespace mediapipe {
using ::mediapipe::NormalizedRect;
namespace {
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";

View File

@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI";
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
using ::mediapipe::NormalizedRect;
using mediapipe::OneEuroFilter;
using ::mediapipe::Rect;
using mediapipe::RelativeVelocityFilter;
void NormalizedLandmarksToLandmarks(

View File

@ -23,6 +23,8 @@ namespace {
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT";
using ::mediapipe::NormalizedRect;
} // namespace
// Projects rectangle from reference coordinate system (defined by reference

View File

@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kRectsTag[] = "RECTS";
constexpr char kRenderDataTag[] = "RENDER_DATA";
using ::mediapipe::NormalizedRect;
using ::mediapipe::Rect;
RenderAnnotation::Rectangle* NewRect(
const RectToRenderDataCalculatorOptions& options, RenderData* render_data) {
auto* annotation = render_data->add_render_annotations();

View File

@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kRenderScaleTag[] = "RENDER_SCALE";
using ::mediapipe::NormalizedRect;
} // namespace
// A calculator to get scale for RenderData primitives.

View File

@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
using ::mediapipe::NormalizedRect;
using ::mediapipe::Rect;
// Wraps around an angle in radians to within -M_PI and M_PI.
inline float NormalizeRadians(float angle) {
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));

View File

@ -22,6 +22,8 @@
namespace mediapipe {
using ::mediapipe::NormalizedRect;
namespace {
constexpr char kLandmarksTag[] = "LANDMARKS";

View File

@ -32,6 +32,8 @@
namespace mediapipe {
namespace {
using ::mediapipe::NormalizedRect;
constexpr int kDetectionUpdateTimeOutMS = 5000;
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES";

View File

@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import android.support.v7.widget.AppCompatEditText;
import androidx.appcompat.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;

View File

@ -29,12 +29,6 @@ objc_library(
"Base.lproj/LaunchScreen.storyboard",
"Base.lproj/Main.storyboard",
],
sdk_frameworks = [
"AVFoundation",
"CoreGraphics",
"CoreMedia",
"UIKit",
],
visibility = [
"//mediapipe:__subpackages__",
],
@ -42,6 +36,10 @@ objc_library(
"//mediapipe/objc:mediapipe_framework_ios",
"//mediapipe/objc:mediapipe_input_sources_ios",
"//mediapipe/objc:mediapipe_layer_renderer",
"//third_party/apple_frameworks:AVFoundation",
"//third_party/apple_frameworks:CoreGraphics",
"//third_party/apple_frameworks:CoreMedia",
"//third_party/apple_frameworks:UIKit",
],
)

View File

@ -73,13 +73,13 @@ objc_library(
"//mediapipe/modules/face_landmark:face_landmark.tflite",
],
features = ["-layering_check"],
sdk_frameworks = [
"AVFoundation",
"CoreGraphics",
"CoreMedia",
"UIKit",
],
deps = [
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//third_party/apple_frameworks:AVFoundation",
"//third_party/apple_frameworks:CoreGraphics",
"//third_party/apple_frameworks:CoreMedia",
"//third_party/apple_frameworks:UIKit",
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto",
"//mediapipe/objc:mediapipe_framework_ios",
"//mediapipe/objc:mediapipe_input_sources_ios",
"//mediapipe/objc:mediapipe_layer_renderer",
@ -87,9 +87,7 @@ objc_library(
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/graphs/face_effect:face_effect_gpu_deps",
"//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto",
],
}),
)

View File

@ -67,12 +67,12 @@ objc_library(
],
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/framework/formats:landmark_cc_proto",
] + select({
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/graphs/face_mesh:mobile_calculators",
"//mediapipe/framework/formats:landmark_cc_proto",
],
}),
)

View File

@ -68,12 +68,12 @@ objc_library(
],
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/framework/formats:landmark_cc_proto",
] + select({
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/graphs/hand_tracking:mobile_calculators",
"//mediapipe/framework/formats:landmark_cc_proto",
],
}),
)

View File

@ -68,12 +68,12 @@ objc_library(
],
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/framework/formats:landmark_cc_proto",
] + select({
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps",
"//mediapipe/framework/formats:landmark_cc_proto",
],
}),
)

View File

@ -67,12 +67,12 @@ objc_library(
],
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/framework/formats:landmark_cc_proto",
] + select({
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps",
"//mediapipe/framework/formats:landmark_cc_proto",
],
}),
)

View File

@ -21,6 +21,7 @@ licenses(["notice"])
package(default_visibility = ["//visibility:private"])
# The MediaPipe internal package group. No mediapipe users should be added to this group.
package_group(
name = "mediapipe_internal",
packages = [
@ -56,12 +57,12 @@ mediapipe_proto_library(
srcs = ["calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:mediapipe_options_proto",
"//mediapipe/framework:packet_factory_proto",
"//mediapipe/framework:packet_generator_proto",
"//mediapipe/framework:status_handler_proto",
"//mediapipe/framework:stream_handler_proto",
":calculator_options_proto",
":mediapipe_options_proto",
":packet_factory_proto",
":packet_generator_proto",
":status_handler_proto",
":stream_handler_proto",
"@com_google_protobuf//:any_proto",
],
)
@ -78,8 +79,8 @@ mediapipe_proto_library(
srcs = ["calculator_contract_test.proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
":calculator_options_proto",
":calculator_proto",
],
)
@ -88,8 +89,8 @@ mediapipe_proto_library(
srcs = ["calculator_profile.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
":calculator_options_proto",
":calculator_proto",
],
)
@ -125,14 +126,14 @@ mediapipe_proto_library(
name = "status_handler_proto",
srcs = ["status_handler.proto"],
visibility = [":mediapipe_internal"],
deps = ["//mediapipe/framework:mediapipe_options_proto"],
deps = [":mediapipe_options_proto"],
)
mediapipe_proto_library(
name = "stream_handler_proto",
srcs = ["stream_handler.proto"],
visibility = [":mediapipe_internal"],
deps = ["//mediapipe/framework:mediapipe_options_proto"],
deps = [":mediapipe_options_proto"],
)
mediapipe_proto_library(
@ -141,8 +142,8 @@ mediapipe_proto_library(
srcs = ["test_calculators.proto"],
visibility = [":mediapipe_internal"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
":calculator_options_proto",
":calculator_proto",
],
)
@ -150,7 +151,7 @@ mediapipe_proto_library(
name = "thread_pool_executor_proto",
srcs = ["thread_pool_executor.proto"],
visibility = [":mediapipe_internal"],
deps = ["//mediapipe/framework:mediapipe_options_proto"],
deps = [":mediapipe_options_proto"],
)
# It is for pure-native Android builds where the library can't have any dependency on libandroid.so

View File

@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
package(default_visibility = [
"//mediapipe:__subpackages__",
])
bzl_library(
name = "expand_template_bzl",
@ -50,13 +52,11 @@ mediapipe_proto_library(
cc_library(
name = "aligned_malloc_and_free",
hdrs = ["aligned_malloc_and_free.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "cleanup",
hdrs = ["cleanup.h"],
visibility = ["//visibility:public"],
deps = ["@com_google_absl//absl/base:core_headers"],
)
@ -86,7 +86,6 @@ cc_library(
# Use this library through "mediapipe/framework/port:gtest_main".
visibility = [
"//mediapipe/framework/port:__pkg__",
"//third_party/visionai/algorithms/tracking:__pkg__",
],
deps = [
"//mediapipe/framework/port:core_proto",
@ -108,7 +107,6 @@ cc_library(
name = "file_helpers",
srcs = ["file_helpers.cc"],
hdrs = ["file_helpers.h"],
visibility = ["//visibility:public"],
deps = [
":file_path",
"//mediapipe/framework/port:status",
@ -134,7 +132,6 @@ cc_library(
cc_library(
name = "image_resizer",
hdrs = ["image_resizer.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/port:opencv_imgproc",
],
@ -151,7 +148,9 @@ cc_library(
cc_library(
name = "mathutil",
hdrs = ["mathutil.h"],
visibility = ["//visibility:public"],
visibility = [
"//mediapipe:__subpackages__",
],
deps = [
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -171,7 +170,6 @@ cc_library(
cc_library(
name = "no_destructor",
hdrs = ["no_destructor.h"],
visibility = ["//visibility:public"],
)
cc_library(
@ -190,7 +188,6 @@ cc_library(
cc_library(
name = "random",
hdrs = ["random_base.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/port:integral_types"],
)
@ -211,14 +208,12 @@ cc_library(
name = "registration_token",
srcs = ["registration_token.cc"],
hdrs = ["registration_token.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "registration",
srcs = ["registration.cc"],
hdrs = ["registration.h"],
visibility = ["//visibility:public"],
deps = [
":registration_token",
"//mediapipe/framework/port:logging",
@ -279,7 +274,6 @@ cc_library(
hdrs = [
"re2.h",
],
visibility = ["//visibility:public"],
)
cc_library(
@ -310,7 +304,6 @@ cc_library(
cc_library(
name = "thread_options",
hdrs = ["thread_options.h"],
visibility = ["//visibility:public"],
)
cc_library(
@ -356,7 +349,6 @@ cc_library(
cc_test(
name = "mathutil_unittest",
srcs = ["mathutil_unittest.cc"],
visibility = ["//visibility:public"],
deps = [
":mathutil",
"//mediapipe/framework/port:benchmark",
@ -368,7 +360,6 @@ cc_test(
name = "registration_token_test",
srcs = ["registration_token_test.cc"],
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":registration_token",
"//mediapipe/framework/port:gtest_main",
@ -381,7 +372,6 @@ cc_test(
timeout = "long",
srcs = ["safe_int_test.cc"],
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":intops",
"//mediapipe/framework/port:gtest_main",
@ -393,7 +383,6 @@ cc_test(
name = "monotonic_clock_test",
srcs = ["monotonic_clock_test.cc"],
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":clock",
"//mediapipe/framework/port:gtest_main",

View File

@ -361,7 +361,7 @@ void Tensor::AllocateOpenGlBuffer() const {
LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread.";
glGenBuffers(1, &opengl_buffer_);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
if (!AllocateAhwbMapToSsbo()) {
if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) {
glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY);
}
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
});
} else
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
{
// Transfer data from texture if not transferred from SSBO/MTLBuffer
// yet.
if (valid_ & kValidOpenGlTexture2d) {
@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
}
});
}
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
valid_ |= kValidCpu;
}
@ -609,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
void Tensor::AllocateCpuBuffer() const {
if (!cpu_buffer_) {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
if (AllocateAHardwareBuffer()) return;
if (use_ahwb_ && AllocateAHardwareBuffer()) return;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_METAL_ENABLED
cpu_buffer_ = AllocateVirtualMemory(bytes());

View File

@ -39,10 +39,9 @@
#endif // MEDIAPIPE_NO_JNI
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <EGL/egl.h>
#include <EGL/eglext.h>
#include <android/hardware_buffer.h>
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h"
@ -410,8 +409,8 @@ class Tensor {
bool AllocateAHardwareBuffer(int size_alignment = 0) const;
void CreateEglSyncAndFd() const;
// Use Ahwb for other views: OpenGL / CPU buffer.
static inline bool use_ahwb_ = false;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
static inline bool use_ahwb_ = false;
// Expects the target SSBO to be already bound.
bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const;
@ -419,6 +418,7 @@ class Tensor {
void ReleaseAhwbStuff();
void* MapAhwbToCpuRead() const;
void* MapAhwbToCpuWrite() const;
void MoveCpuOrSsboToAhwb() const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
mutable std::shared_ptr<mediapipe::GlContext> gl_context_;

View File

@ -4,12 +4,13 @@
#include "mediapipe/framework/formats/tensor.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <EGL/egl.h>
#include <EGL/eglext.h>
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/gpu/gl_base.h"
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
namespace mediapipe {
@ -213,11 +214,16 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
"supported.";
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
"supported on targe system.";
"supported on target system.";
bool transfer = !ahwb_;
CHECK(AllocateAHardwareBuffer())
<< "AHardwareBuffer is not supported on the target system.";
valid_ |= kValidAHardwareBuffer;
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
if (transfer) {
MoveCpuOrSsboToAhwb();
} else {
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
}
return {ahwb_,
ssbo_written_,
&fence_fd_, // The FD is created for SSBO -> AHWB synchronization.
@ -262,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
}
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
if (!use_ahwb_) return false;
if (__builtin_available(android 26, *)) {
if (ahwb_ == nullptr) {
AHardwareBuffer_Desc desc = {};
@ -302,6 +307,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const {
return false;
}
// Moves Cpu/Ssbo resource under the Ahwb backed memory.
void Tensor::MoveCpuOrSsboToAhwb() const {
void* dest = nullptr;
if (__builtin_available(android 26, *)) {
auto error = AHardwareBuffer_lock(
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
}
if (valid_ & kValidOpenGlBuffer) {
gl_context_->Run([this, dest]() {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
GL_MAP_READ_BIT);
std::memcpy(dest, src, bytes());
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
glDeleteBuffers(1, &opengl_buffer_);
});
opengl_buffer_ = GL_INVALID_INDEX;
gl_context_ = nullptr;
} else if (valid_ & kValidCpu) {
std::memcpy(dest, cpu_buffer_, bytes());
// Free CPU memory because next time AHWB is mapped instead.
free(cpu_buffer_);
cpu_buffer_ = nullptr;
} else {
LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
}
if (__builtin_available(android 26, *)) {
auto error = AHardwareBuffer_unlock(ahwb_, nullptr);
CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
}
}
// SSBO is created on top of AHWB. A fence is inserted into the GPU queue before
// the GPU task that is going to read from the SSBO. When the writing into AHWB
// is finished then the GPU reads from the SSBO.

View File

@ -0,0 +1,171 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/formats/tensor_data_types.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "mediapipe/gpu/shader_util.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "testing/base/public/gunit.h"
// The test creates OpenGL ES buffer, fills the buffer with incrementing values
// 0.0, 0.1, 0.2 etc. with the compute shader on GPU.
// Then the test requests the CPU view and compares the values.
// Float32 and Float16 tests are there.
namespace {
using mediapipe::Float16;
using mediapipe::Tensor;
MATCHER_P(NearWithPrecision, precision, "") {
return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision;
}
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
// Utility function to fill the GPU buffer.
void FillGpuBuffer(GLuint name, std::size_t size,
const Tensor::ElementType fmt) {
std::string shader_source;
if (fmt == Tensor::ElementType::kFloat32) {
shader_source = R"( #version 310 es
precision highp float;
layout(local_size_x = 1, local_size_y = 1) in;
layout(std430, binding = 0) buffer Output {float elements[];} output_data;
void main() {
uint v = gl_GlobalInvocationID.x * 2u;
output_data.elements[v] = float(v) / 10.0;
output_data.elements[v + 1u] = float(v + 1u) / 10.0;
})";
} else {
shader_source = R"( #version 310 es
precision highp float;
layout(local_size_x = 1, local_size_y = 1) in;
layout(std430, binding = 0) buffer Output {float elements[];} output_data;
void main() {
uint v = gl_GlobalInvocationID.x;
uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0,
(float(v) * 2.0 + 1.0) / 10.0));
output_data.elements[v] = uintBitsToFloat(tmp);
})";
}
GLuint shader;
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER));
const GLchar* sources[] = {shader_source.c_str()};
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader));
GLint is_compiled = 0;
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS,
&is_compiled));
if (is_compiled == GL_FALSE) {
GLint max_length = 0;
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH,
&max_length));
std::vector<GLchar> error_log(max_length);
glGetShaderInfoLog(shader, max_length, &max_length, error_log.data());
glDeleteShader(shader);
FAIL() << error_log.data();
return;
}
GLuint to_buffer_program;
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program));
MP_ASSERT_OK(
TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0));
MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program));
}
class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
public:
};
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
EXPECT_GT(ssbo_name, 0);
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
tensor.element_type());
});
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);
std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference));
}
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
EXPECT_GT(ssbo_name, 0);
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
tensor.element_type());
});
auto ptr = tensor.GetCpuReadView().buffer<Float16>();
EXPECT_NE(ptr, nullptr);
std::vector<Float16> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
// Precision is set to a reasonable value for Float16.
EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements),
testing::Pointwise(NearWithPrecision(0.001), reference));
}
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
// Request the CPU view to get the memory to be allocated.
// Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
for (int i = 0; i < num_elements; i++) {
ptr[i] = static_cast<float>(i) / 10.0f;
}
}
{
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
}
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);
std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference));
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
} // namespace
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -0,0 +1,71 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_buffer.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/formats/tensor_v2.h"
namespace mediapipe {
// Supports:
// - float 16 and 32 bits
// - signed / unsigned integers 8,16,32 bits
class TensorHardwareBufferView;
struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
using ViewT = TensorHardwareBufferView;
TensorBufferDescriptor buffer;
};
class TensorHardwareBufferView : public Tensor::View {
public:
TENSOR_UNIQUE_VIEW_TYPE_ID();
~TensorHardwareBufferView() = default;
const TensorHardwareBufferViewDescriptor& descriptor() const override {
return descriptor_;
}
AHardwareBuffer* handle() const { return ahwb_handle_; }
protected:
TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& desc,
AHardwareBuffer* ahwb_handle)
: Tensor::View(kId, access_capability, access, state),
descriptor_(desc),
ahwb_handle_(ahwb_handle) {}
private:
bool MatchDescriptor(
uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) const override {
if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
return false;
auto descriptor =
static_cast<const TensorHardwareBufferViewDescriptor&>(base_descriptor);
return descriptor.buffer.format == descriptor_.buffer.format &&
descriptor.buffer.size_alignment <=
descriptor_.buffer.size_alignment &&
descriptor_.buffer.size_alignment %
descriptor.buffer.size_alignment ==
0;
}
const TensorHardwareBufferViewDescriptor& descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
};
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_

View File

@ -0,0 +1,216 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <cstdint>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/framework/formats/tensor_backend.h"
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "util/task/status_macros.h"
namespace mediapipe {
namespace {
class TensorCpuViewImpl : public TensorCpuView {
public:
TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
Tensor::View::State state,
const TensorCpuViewDescriptor& descriptor, void* pointer,
AHardwareBuffer* ahwb_handle)
: TensorCpuView(access_capabilities, access, state, descriptor, pointer),
ahwb_handle_(ahwb_handle) {}
~TensorCpuViewImpl() {
// If handle_ is null then this view is constructed in GetViews with no
// access.
if (ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_unlock(ahwb_handle_, nullptr);
}
}
}
private:
AHardwareBuffer* ahwb_handle_;
};
class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
public:
TensorHardwareBufferViewImpl(
int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& descriptor,
AHardwareBuffer* handle)
: TensorHardwareBufferView(access_capability, access, state, descriptor,
handle) {}
~TensorHardwareBufferViewImpl() = default;
};
class HardwareBufferCpuStorage : public TensorStorage {
public:
~HardwareBufferCpuStorage() {
if (!ahwb_handle_) return;
if (__builtin_available(android 26, *)) {
AHardwareBuffer_release(ahwb_handle_);
}
}
static absl::Status CanProvide(
int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) {
// TODO: use AHardwareBuffer_isSupported for API >= 29.
static const bool is_ahwb_supported = [] {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
// Aligned to the largest possible virtual memory page size.
constexpr uint32_t kPageSize = 16384;
desc.width = kPageSize;
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
AHardwareBuffer* handle;
if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
AHardwareBuffer_release(handle);
return true;
}
return false;
}();
if (!is_ahwb_supported) {
return absl::UnavailableError(
"AHardwareBuffer is not supported on the platform.");
}
if (view_type_id != TensorCpuView::kId &&
view_type_id != TensorHardwareBufferView::kId) {
return absl::InvalidArgumentError(
"A view type is not supported by this storage.");
}
return absl::OkStatus();
}
std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
std::vector<std::unique_ptr<Tensor::View>> result;
auto update_state = latest_version == version_
? Tensor::View::State::kUpToDate
: Tensor::View::State::kOutdated;
if (ahwb_handle_) {
result.push_back(
std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
hw_descriptor_, ahwb_handle_)));
result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
cpu_descriptor_, nullptr, nullptr)));
}
return result;
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
Tensor::View::Access access, const Tensor::Shape& shape,
uint64_t latest_version, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
MP_RETURN_IF_ERROR(
CanProvide(access_capability, shape, view_type_id, base_descriptor));
const auto& buffer_descriptor =
view_type_id == TensorHardwareBufferView::kId
? static_cast<const TensorHardwareBufferViewDescriptor&>(
base_descriptor)
.buffer
: static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
.buffer;
if (!ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
desc.width = TensorBufferSize(buffer_descriptor, shape);
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
// TODO: Use access capabilities to set hints.
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error allocating hardware buffer: ", error));
}
// Fill all possible views to provide it as proto views.
hw_descriptor_.buffer = buffer_descriptor;
cpu_descriptor_.buffer = buffer_descriptor;
}
}
if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
buffer_descriptor.size_alignment >
hw_descriptor_.buffer.size_alignment ||
hw_descriptor_.buffer.size_alignment %
buffer_descriptor.size_alignment >
0) {
return absl::AlreadyExistsError(
"A view with different params is already allocated with this "
"storage");
}
absl::StatusOr<std::unique_ptr<Tensor::View>> result;
if (view_type_id == TensorHardwareBufferView::kId) {
result = GetAhwbView(access, shape, base_descriptor);
} else {
result = GetCpuView(access, shape, base_descriptor);
}
if (result.ok()) version_ = latest_version;
return result;
}
private:
absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, access, Tensor::View::State::kUpToDate,
hw_descriptor_, ahwb_handle_));
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
void* pointer = nullptr;
if (__builtin_available(android 26, *)) {
int error =
AHardwareBuffer_lock(ahwb_handle_,
access == Tensor::View::Access::kWriteOnly
? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
: AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
-1, nullptr, &pointer);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error locking hardware buffer: ", error));
}
}
return std::unique_ptr<Tensor::View>(
new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
? Tensor::View::AccessCapability::kWrite
: Tensor::View::AccessCapability::kRead,
access, Tensor::View::State::kUpToDate,
cpu_descriptor_, pointer, ahwb_handle_));
}
static constexpr int kAccessCapability =
Tensor::View::AccessCapability::kRead |
Tensor::View::AccessCapability::kWrite;
TensorHardwareBufferViewDescriptor hw_descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
TensorCpuViewDescriptor cpu_descriptor_;
uint64_t version_ = 0;
};
TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -0,0 +1,76 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace {
class TensorHardwareBufferTest : public ::testing::Test {
public:
TensorHardwareBufferTest() {}
~TensorHardwareBufferTest() override {}
};
TEST_F(TensorHardwareBufferTest, TestFloat32) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
TEST_F(TensorHardwareBufferTest, TestInt8Padding) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8,
.size_alignment = 4}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream {
friend class GraphProfiler;
// Accesses OutputStreamShard for profiling.
friend class GraphTracer;
// Accesses OutputStreamShard for profiling.
friend class PerfettoTraceScope;
// Accesses OutputStreamShard for post processing.
friend class OutputStreamManager;
};

View File

@ -18,7 +18,7 @@
licenses(["notice"])
package(
default_visibility = ["//visibility:private"],
default_visibility = ["//visibility:public"],
features = ["-parse_headers"],
)
@ -28,7 +28,6 @@ config_setting(
define_values = {
"USE_MEDIAPIPE_THREADPOOL": "1",
},
visibility = ["//visibility:public"],
)
#TODO : remove from OSS.
@ -37,13 +36,11 @@ config_setting(
define_values = {
"USE_MEDIAPIPE_THREADPOOL": "0",
},
visibility = ["//visibility:public"],
)
cc_library(
name = "aligned_malloc_and_free",
hdrs = ["aligned_malloc_and_free.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/deps:aligned_malloc_and_free",
"@com_google_absl//absl/base:core_headers",
@ -57,7 +54,6 @@ cc_library(
"advanced_proto_inc.h",
"proto_ns.h",
],
visibility = ["//visibility:public"],
deps = [
":advanced_proto_lite",
":core_proto",
@ -72,7 +68,6 @@ cc_library(
"advanced_proto_lite_inc.h",
"proto_ns.h",
],
visibility = ["//visibility:public"],
deps = [
":core_proto",
"//mediapipe/framework:port",
@ -83,7 +78,6 @@ cc_library(
cc_library(
name = "any_proto",
hdrs = ["any_proto.h"],
visibility = ["//visibility:public"],
deps = [
":core_proto",
],
@ -94,7 +88,6 @@ cc_library(
hdrs = [
"commandlineflags.h",
],
visibility = ["//visibility:public"],
deps = [
"//third_party:glog",
"@com_google_absl//absl/flags:flag",
@ -107,7 +100,6 @@ cc_library(
"core_proto_inc.h",
"proto_ns.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"@com_google_protobuf//:protobuf",
@ -117,7 +109,6 @@ cc_library(
cc_library(
name = "file_helpers",
hdrs = ["file_helpers.h"],
visibility = ["//visibility:public"],
deps = [
":status",
"//mediapipe/framework/deps:file_helpers",
@ -128,7 +119,6 @@ cc_library(
cc_library(
name = "image_resizer",
hdrs = ["image_resizer.h"],
visibility = ["//visibility:public"],
deps = select({
"//conditions:default": [
"//mediapipe/framework/deps:image_resizer",
@ -140,14 +130,12 @@ cc_library(
cc_library(
name = "integral_types",
hdrs = ["integral_types.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "benchmark",
testonly = 1,
hdrs = ["benchmark.h"],
visibility = ["//visibility:public"],
deps = [
"@com_google_benchmark//:benchmark",
],
@ -158,7 +146,6 @@ cc_library(
hdrs = [
"re2.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/deps:re2",
],
@ -173,7 +160,6 @@ cc_library(
"gtest-spi.h",
"status_matchers.h",
],
visibility = ["//visibility:public"],
deps = [
":status_matchers",
"//mediapipe/framework/deps:message_matchers",
@ -190,7 +176,6 @@ cc_library(
"gtest-spi.h",
"status_matchers.h",
],
visibility = ["//visibility:public"],
deps = [
":status_matchers",
"//mediapipe/framework/deps:message_matchers",
@ -204,7 +189,6 @@ cc_library(
hdrs = [
"logging.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"//third_party:glog",
@ -217,7 +201,6 @@ cc_library(
hdrs = [
"map_util.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework/deps:map_util",
@ -227,7 +210,6 @@ cc_library(
cc_library(
name = "numbers",
hdrs = ["numbers.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/deps:numbers"],
)
@ -238,13 +220,11 @@ config_setting(
define_values = {
"MEDIAPIPE_DISABLE_OPENCV": "1",
},
visibility = ["//visibility:public"],
)
cc_library(
name = "opencv_core",
hdrs = ["opencv_core_inc.h"],
visibility = ["//visibility:public"],
deps = [
"//third_party:opencv",
],
@ -253,7 +233,6 @@ cc_library(
cc_library(
name = "opencv_imgproc",
hdrs = ["opencv_imgproc_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -263,7 +242,6 @@ cc_library(
cc_library(
name = "opencv_imgcodecs",
hdrs = ["opencv_imgcodecs_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -273,7 +251,6 @@ cc_library(
cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -283,7 +260,6 @@ cc_library(
cc_library(
name = "opencv_video",
hdrs = ["opencv_video_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//mediapipe/framework:port",
@ -294,7 +270,6 @@ cc_library(
cc_library(
name = "opencv_features2d",
hdrs = ["opencv_features2d_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -304,7 +279,6 @@ cc_library(
cc_library(
name = "opencv_calib3d",
hdrs = ["opencv_calib3d_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -314,7 +288,6 @@ cc_library(
cc_library(
name = "opencv_videoio",
hdrs = ["opencv_videoio_inc.h"],
visibility = ["//visibility:public"],
deps = [
":opencv_core",
"//mediapipe/framework:port",
@ -328,7 +301,6 @@ cc_library(
"parse_text_proto.h",
"proto_ns.h",
],
visibility = ["//visibility:public"],
deps = [
":core_proto",
":logging",
@ -339,14 +311,12 @@ cc_library(
cc_library(
name = "point",
hdrs = ["point2.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/deps:point"],
)
cc_library(
name = "port",
hdrs = ["port.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"@com_google_absl//absl/base:core_headers",
@ -356,14 +326,12 @@ cc_library(
cc_library(
name = "rectangle",
hdrs = ["rectangle.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/deps:rectangle"],
)
cc_library(
name = "ret_check",
hdrs = ["ret_check.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework/deps:ret_check",
@ -373,7 +341,6 @@ cc_library(
cc_library(
name = "singleton",
hdrs = ["singleton.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/deps:singleton"],
)
@ -382,7 +349,6 @@ cc_library(
hdrs = [
"source_location.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework/deps:source_location",
@ -397,7 +363,6 @@ cc_library(
"status_builder.h",
"status_macros.h",
],
visibility = ["//visibility:public"],
deps = [
":source_location",
"//mediapipe/framework:port",
@ -412,7 +377,6 @@ cc_library(
hdrs = [
"statusor.h",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"@com_google_absl//absl/status:statusor",
@ -423,7 +387,6 @@ cc_library(
name = "status_matchers",
testonly = 1,
hdrs = ["status_matchers.h"],
visibility = ["//visibility:private"],
deps = [
":status",
"@com_google_googletest//:gtest",
@ -433,7 +396,6 @@ cc_library(
cc_library(
name = "threadpool",
hdrs = ["threadpool.h"],
visibility = ["//visibility:public"],
deps = select({
"//conditions:default": [":threadpool_impl_default_to_google"],
"//mediapipe:android": [":threadpool_impl_default_to_mediapipe"],
@ -460,7 +422,6 @@ alias(
cc_library(
name = "topologicalsorter",
hdrs = ["topologicalsorter.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework/deps:topologicalsorter",
@ -470,6 +431,5 @@ cc_library(
cc_library(
name = "vector",
hdrs = ["vector.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/deps:vector"],
)

View File

@ -228,6 +228,8 @@ def mediapipe_ts_library(
srcs = srcs,
visibility = visibility,
deps = deps + [
"@npm//@types/jasmine",
"@npm//@types/node",
"@npm//@types/offscreencanvas",
"@npm//@types/google-protobuf",
],

View File

@ -140,7 +140,7 @@ cc_library(
name = "circular_buffer",
hdrs = ["circular_buffer.h"],
visibility = [
"//visibility:public",
"//mediapipe:__subpackages__",
],
deps = [
"//mediapipe/framework/port:integral_types",
@ -151,7 +151,6 @@ cc_test(
name = "circular_buffer_test",
size = "small",
srcs = ["circular_buffer_test.cc"],
visibility = ["//visibility:public"],
deps = [
":circular_buffer",
"//mediapipe/framework/port:gtest_main",
@ -164,7 +163,7 @@ cc_library(
name = "trace_buffer",
srcs = ["trace_buffer.h"],
hdrs = ["trace_buffer.h"],
visibility = ["//visibility:public"],
visibility = ["//mediapipe/framework/profiler:__subpackages__"],
deps = [
":circular_buffer",
"//mediapipe/framework:calculator_profile_cc_proto",
@ -292,9 +291,7 @@ cc_library(
"-ObjC++",
],
}),
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/flags:flag",
"//mediapipe/framework/port:logging",

View File

@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
const ProfilerConfig& profiler_config() { return profiler_config_; }
// Helper method to expose the config to other profilers.
const ValidatedGraphConfig* GetValidatedGraphConfig() {
return validated_graph_;
}
private:
// This can be used to add packet info for the input streams to the graph.
// It treats the stream defined by |stream_name| as a stream produced by a

View File

@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() {
// Note: state_mutex_ is held when this function is entered or
// exited.
void Scheduler::HandleIdle() {
if (handling_idle_) {
if (++handling_idle_ > 1) {
// Someone is already inside this method.
// Note: This can happen in the sections below where we unlock the mutex
// and make more nodes runnable: the nodes can run and become idle again
@ -127,7 +127,6 @@ void Scheduler::HandleIdle() {
VLOG(2) << "HandleIdle: already in progress";
return;
}
handling_idle_ = true;
while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) {
// Remove active sources that are closed.
@ -165,11 +164,17 @@ void Scheduler::HandleIdle() {
}
}
// If HandleIdle has been called again, then continue scheduling.
if (handling_idle_ > 1) {
handling_idle_ = 1;
continue;
}
// Nothing left to do.
break;
}
handling_idle_ = false;
handling_idle_ = 0;
}
// Note: state_mutex_ is held when this function is entered or exited.

View File

@ -302,7 +302,7 @@ class Scheduler {
// - We need it to be reentrant, which Mutex does not support.
// - We want simultaneous calls to return immediately instead of waiting,
// and Mutex's TryLock is not guaranteed to work.
bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false;
int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0;
// Mutex for the scheduler state and related things.
// Note: state_ is declared as atomic so that its getter methods don't need

View File

@ -90,7 +90,7 @@ mediapipe_proto_library(
name = "packet_generator_wrapper_calculator_proto",
srcs = ["packet_generator_wrapper_calculator.proto"],
def_py_proto = False,
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:packet_generator_proto",
@ -120,13 +120,13 @@ cc_library(
name = "fill_packet_set",
srcs = ["fill_packet_set.cc"],
hdrs = ["fill_packet_set.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
":status_util",
"//mediapipe/framework:packet_set",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/memory",
],
)
@ -162,7 +162,6 @@ cc_library(
cc_test(
name = "executor_util_test",
srcs = ["executor_util_test.cc"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":executor_util",
"//mediapipe/framework/port:gtest_main",
@ -173,7 +172,7 @@ cc_test(
cc_library(
name = "options_map",
hdrs = ["options_map.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//mediapipe:__subpackages__"],
deps = [
":type_util",
"//mediapipe/framework:calculator_cc_proto",
@ -193,7 +192,7 @@ cc_library(
name = "options_field_util",
srcs = ["options_field_util.cc"],
hdrs = ["options_field_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//visibility:private"],
deps = [
":field_data_cc_proto",
":name_util",
@ -216,7 +215,7 @@ cc_library(
name = "options_syntax_util",
srcs = ["options_syntax_util.cc"],
hdrs = ["options_syntax_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//visibility:private"],
deps = [
":name_util",
":options_field_util",
@ -235,8 +234,9 @@ cc_library(
name = "options_util",
srcs = ["options_util.cc"],
hdrs = ["options_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//visibility:public"],
deps = [
":name_util",
":options_field_util",
":options_map",
":options_registry",
@ -254,7 +254,6 @@ cc_library(
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util",
"@com_google_absl//absl/strings",
],
)
@ -323,7 +322,7 @@ mediapipe_cc_test(
cc_library(
name = "packet_generator_wrapper_calculator",
srcs = ["packet_generator_wrapper_calculator.cc"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
":packet_generator_wrapper_calculator_cc_proto",
"//mediapipe/framework:calculator_base",
@ -347,6 +346,7 @@ cc_library(
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/strings",
],
)
@ -507,6 +507,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:proto_descriptor_cc_proto",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:integral_types",

View File

@ -27,6 +27,9 @@ message TemplateExpression {
// The FieldDescriptor::Type of the modified field.
optional mediapipe.FieldDescriptorProto.Type field_type = 5;
// The FieldDescriptor::Type of each map key in the path.
repeated mediapipe.FieldDescriptorProto.Type key_type = 6;
// Alternative value for the modified field, in protobuf binary format.
optional string field_value = 7;
}

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/type_map.h"
@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
// Extracts the data value(s) for one field from a serialized message.
// The message with these field values removed is written to |out|.
absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
CodedInputStream* in, CodedOutputStream* out,
absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in,
CodedOutputStream* out,
std::vector<std::string>* field_values) {
uint32 tag;
while ((tag = in->ReadTag()) != 0) {
int field_number = WireFormatLite::GetTagFieldNumber(tag);
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (field_number == field_id) {
if (!IsLengthDelimited(wire_type) &&
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) {
CodedInputStream in(&ais);
StringOutputStream sos(&message_);
CodedOutputStream out(&sos);
WireFormatLite::WireType wire_type =
WireFormatLite::WireTypeForFieldType(field_type_);
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
return GetFieldValues(field_id_, &in, &out, &field_values_);
}
void FieldAccess::GetMessage(std::string* result) {
@ -149,18 +149,56 @@ std::vector<FieldValue>* FieldAccess::mutable_field_values() {
return &field_values_;
}
namespace {
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
// Returns the FieldAccess and index for a field-id or a map-id.
// Returns access to the field-id if the field index is found,
// to the map-id if the map entry is found, and to the field-id otherwise.
absl::StatusOr<std::pair<FieldAccess, int>> AccessField(
const ProtoPathEntry& entry, FieldType field_type,
const FieldValue& message) {
FieldAccess result(entry.field_id, field_type);
if (entry.field_id >= 0) {
MP_RETURN_IF_ERROR(result.SetMessage(message));
if (entry.index < result.mutable_field_values()->size()) {
return std::pair(result, entry.index);
}
}
if (entry.map_id >= 0) {
FieldAccess access(entry.map_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(message));
auto& field_values = *access.mutable_field_values();
for (int index = 0; index < field_values.size(); ++index) {
FieldAccess key(entry.key_id, entry.key_type);
MP_RETURN_IF_ERROR(key.SetMessage(field_values[index]));
if (key.mutable_field_values()->at(0) == entry.key_value) {
return std::pair(std::move(access), index);
}
}
}
if (entry.field_id >= 0) {
return std::pair(result, entry.index);
}
return absl::InvalidArgumentError(absl::StrCat(
"ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ",
entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type));
}
} // namespace
// Replaces a range of field values for one field nested within a protobuf.
absl::Status ProtoUtilLite::ReplaceFieldRange(
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
const std::vector<FieldValue>& field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
MP_RETURN_IF_ERROR(access.SetMessage(*message));
std::vector<std::string>& v = *access.mutable_field_values();
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
absl::Status ProtoUtilLite::GetFieldRange(
const FieldValue& message, ProtoPath proto_path, int length,
FieldType field_type, std::vector<FieldValue>* field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
MP_RETURN_IF_ERROR(access.SetMessage(message));
std::vector<std::string>& v = *access.mutable_field_values();
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else {
if (length == -1) {
length = v.size() - index;
}
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index,
@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
ProtoPath proto_path,
FieldType field_type,
int* field_count) {
int field_id, index;
std::tie(field_id, index) = proto_path.back();
proto_path.pop_back();
std::vector<std::string> parent;
if (proto_path.empty()) {
parent.push_back(std::string(message));
ProtoPathEntry entry = proto_path.front();
proto_path.erase(proto_path.begin());
FieldType type =
!proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type;
ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message));
FieldAccess& access = r.first;
int index = r.second;
std::vector<FieldValue>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldCount(v[index], proto_path, field_type, field_count));
} else {
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
*field_count = v.size();
}
FieldAccess access(field_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
*field_count = access.mutable_field_values()->size();
return absl::OkStatus();
}

View File

@ -34,15 +34,36 @@ class ProtoUtilLite {
// Defines field types and tag formats.
using WireFormatLite = proto_ns::internal::WireFormatLite;
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<std::pair<int, int>>;
// The serialized value for a protobuf field.
using FieldValue = std::string;
// The serialized data type for a protobuf field.
using FieldType = WireFormatLite::FieldType;
// A field-id and index, or a map-id and key, or both.
struct ProtoPathEntry {
ProtoPathEntry(int id, int index) : field_id(id), index(index) {}
ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value)
: map_id(id),
key_id(key_id),
key_type(key_type),
key_value(std::move(key_value)) {}
bool operator==(const ProtoPathEntry& o) const {
return field_id == o.field_id && index == o.index && map_id == o.map_id &&
key_id == o.key_id && key_type == o.key_type &&
key_value == o.key_value;
}
int field_id = -1;
int index = -1;
int map_id = -1;
int key_id = -1;
FieldType key_type = FieldType::MAX_FIELD_TYPE;
FieldValue key_value;
};
// Defines a sequence of nested field-number field-index pairs.
using ProtoPath = std::vector<ProtoPathEntry>;
class FieldAccess {
public:
// Provides access to a certain protobuf field.
@ -57,9 +78,11 @@ class ProtoUtilLite {
// Returns the serialized values of the protobuf field.
std::vector<FieldValue>* mutable_field_values();
uint32 field_id() const { return field_id_; }
private:
const uint32 field_id_;
const FieldType field_type_;
uint32 field_id_;
FieldType field_type_;
std::string message_;
std::vector<FieldValue> field_values_;
};

View File

@ -22,6 +22,7 @@
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldValue = ProtoUtilLite::FieldValue;
using FieldType = ProtoUtilLite::FieldType;
using ProtoPath = ProtoUtilLite::ProtoPath;
using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry;
namespace {
@ -84,26 +86,87 @@ std::unique_ptr<MessageLite> CloneMessage(const MessageLite& message) {
return result;
}
// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
absl::Status status;
std::vector<std::string> ids = absl::StrSplit(path, '/');
for (const std::string& id : ids) {
if (id.length() > 0) {
std::pair<std::string, std::string> id_pair =
absl::StrSplit(id, absl::ByAnyChar("[]"));
int tag = 0;
int index = 0;
bool ok = absl::SimpleAtoi(id_pair.first, &tag) &&
absl::SimpleAtoi(id_pair.second, &index);
if (!ok) {
status.Update(absl::InvalidArgumentError(path));
}
result->push_back(std::make_pair(tag, index));
// Parses one ProtoPathEntry.
// The parsed entry is appended to `result` and removed from `path`.
// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes
// to serialize the key text to protobuf wire format.
absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) {
bool ok = true;
int sb = path.find('[');
int eb = path.find(']');
int field_id = -1;
ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id);
auto selector = path.substr(sb + 1, eb - 1 - sb);
if (absl::StartsWith(selector, "@")) {
int eq = selector.find('=');
int key_id = -1;
ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id);
auto key_text = selector.substr(eq + 1);
FieldType key_type = FieldType::TYPE_STRING;
result->push_back({field_id, key_id, key_type, std::string(key_text)});
} else {
int index = 0;
ok &= absl::SimpleAtoi(selector, &index);
result->push_back({field_id, index});
}
int end = path.find('/', eb);
if (end == std::string::npos) {
path = "";
} else {
path = path.substr(end + 1);
}
return ok ? absl::OkStatus()
: absl::InvalidArgumentError(
absl::StrCat("Failed to parse ProtoPath entry: ", path));
}
// Specifies the FieldTypes for protobuf map keys in a ProtoPath.
// Each ProtoPathEntry::key_value is converted from text to the protobuf
// wire format for its key type.
absl::Status SetMapKeyTypes(const std::vector<FieldType>& key_types,
ProtoPath* result) {
int i = 0;
for (ProtoPathEntry& entry : *result) {
if (entry.map_id >= 0) {
FieldType key_type = key_types[i++];
std::vector<FieldValue> key_value;
MP_RETURN_IF_ERROR(
ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value));
entry.key_type = key_type;
entry.key_value = key_value.front();
}
}
return status;
return absl::OkStatus();
}
// Returns the (tag, index) pairs in a field path.
// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]",
// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]".
absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) {
result->clear();
absl::string_view rest = path;
if (absl::StartsWith(rest, "/")) {
rest = rest.substr(1);
}
while (!rest.empty()) {
MP_RETURN_IF_ERROR(ParseEntry(rest, result));
}
return absl::OkStatus();
}
// Parse the TemplateExpression.path field into a ProtoPath struct.
absl::Status ParseProtoPath(const TemplateExpression& rule,
std::string base_path, ProtoPath* result) {
ProtoPath base_entries;
MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries));
MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result));
std::vector<FieldType> key_types;
for (int type : rule.key_type()) {
key_types.push_back(static_cast<FieldType>(type));
}
MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result));
result->erase(result->begin(), result->begin() + base_entries.size());
return absl::OkStatus();
}
// Returns true if one proto path is prefix by another.
@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) {
return absl::StartsWith(path, prefix);
}
// Returns the part of one proto path after a prefix proto path.
std::string ProtoPathRelative(const std::string& field_path,
const std::string& base_path) {
CHECK(ProtoPathStartsWith(field_path, base_path));
return field_path.substr(base_path.length());
}
// Returns the target ProtoUtilLite::FieldType of a rule.
FieldType GetFieldType(const TemplateExpression& rule) {
return static_cast<FieldType>(rule.field_type());
@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) {
// Returns the count of field values at a ProtoPath.
int FieldCount(const FieldValue& base, ProtoPath field_path,
FieldType field_type) {
int field_id, index;
std::tie(field_id, index) = field_path.back();
field_path.pop_back();
std::vector<FieldValue> parent;
if (field_path.empty()) {
parent.push_back(base);
} else {
MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange(
base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
ProtoUtilLite::FieldAccess access(field_id, field_type);
MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0]));
return access.mutable_field_values()->size();
int result = 0;
CHECK(
ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok());
return result;
}
} // namespace
@ -229,9 +276,7 @@ class TemplateExpanderImpl {
return absl::OkStatus();
}
ProtoPath field_path;
absl::Status status =
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path);
if (!status.ok()) return status;
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
return ProtoUtilLite::GetFieldRange(output, field_path, 1,
GetFieldType(rule), base);
}
@ -242,12 +287,13 @@ class TemplateExpanderImpl {
const std::vector<FieldValue>& field_values,
FieldValue* output) {
if (!rule.has_path()) {
*output = field_values[0];
if (!field_values.empty()) {
*output = field_values[0];
}
return absl::OkStatus();
}
ProtoPath field_path;
RET_CHECK_OK(
ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path));
MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path));
int field_count = 1;
if (rule.has_field_value()) {
// For a non-repeated field, only one value can be specified.
@ -257,7 +303,7 @@ class TemplateExpanderImpl {
"Multiple values specified for non-repeated field: ", rule.path()));
}
// For a non-repeated field, the field value is stored only in the rule.
field_path[field_path.size() - 1].second = 0;
field_path[field_path.size() - 1].index = 0;
field_count = 0;
}
return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count,

View File

@ -26,6 +26,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/proto_descriptor.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message;
using mediapipe::proto_ns::OneofDescriptor;
using mediapipe::proto_ns::Reflection;
using mediapipe::proto_ns::TextFormat;
using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath;
using FieldType = mediapipe::tool::ProtoUtilLite::FieldType;
using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue;
namespace mediapipe {
@ -1357,32 +1361,138 @@ absl::Status ProtoPathSplit(const std::string& path,
if (!ok) {
status.Update(absl::InvalidArgumentError(path));
}
result->push_back(std::make_pair(tag, index));
result->push_back({tag, index});
}
}
return status;
}
// Returns a message serialized deterministically.
bool DeterministicallySerialize(const Message& proto, std::string* result) {
proto_ns::io::StringOutputStream stream(result);
proto_ns::io::CodedOutputStream output(&stream);
output.SetSerializationDeterministic(true);
return proto.SerializeToCodedStream(&output);
}
// Serialize one field of a message.
void SerializeField(const Message* message, const FieldDescriptor* field,
std::vector<ProtoUtilLite::FieldValue>* result) {
ProtoUtilLite::FieldValue message_bytes;
CHECK(message->SerializePartialToString(&message_bytes));
CHECK(DeterministicallySerialize(*message, &message_bytes));
ProtoUtilLite::FieldAccess access(
field->number(), static_cast<ProtoUtilLite::FieldType>(field->type()));
MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes));
*result = *access.mutable_field_values();
}
// Serialize a ProtoPath as a readable string.
// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]",
// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]".
std::string ProtoPathJoin(ProtoPath path) {
std::string result;
for (ProtoUtilLite::ProtoPathEntry& e : path) {
if (e.field_id >= 0) {
absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]");
} else if (e.map_id >= 0) {
absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value,
"]");
}
}
return result;
}
// Returns the message value from a field at an index.
const Message* GetFieldMessage(const Message& message,
const FieldDescriptor* field, int index) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
if (!field->is_repeated()) {
return &message.GetReflection()->GetMessage(message, field);
}
if (index < message.GetReflection()->FieldSize(message, field)) {
return &message.GetReflection()->GetRepeatedMessage(message, field, index);
}
return nullptr;
}
// Returns all FieldDescriptors including extensions.
std::vector<const FieldDescriptor*> GetFields(const Message* src) {
std::vector<const FieldDescriptor*> result;
src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(),
&result);
for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) {
result.push_back(src->GetDescriptor()->field(i));
}
return result;
}
// Orders map entries in dst to match src.
void OrderMapEntries(const Message* src, Message* dst,
std::set<const Message*>* seen = nullptr) {
std::unique_ptr<std::set<const Message*>> seen_owner;
if (!seen) {
seen_owner = std::make_unique<std::set<const Message*>>();
seen = seen_owner.get();
}
if (seen->count(src) > 0) {
return;
} else {
seen->insert(src);
}
for (auto field : GetFields(src)) {
if (field->is_map()) {
dst->GetReflection()->ClearField(dst, field);
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
const Message& entry =
src->GetReflection()->GetRepeatedMessage(*src, field, j);
dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry);
}
}
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (field->is_repeated()) {
for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) {
OrderMapEntries(
&src->GetReflection()->GetRepeatedMessage(*src, field, j),
dst->GetReflection()->MutableRepeatedMessage(dst, field, j),
seen);
}
} else {
OrderMapEntries(&src->GetReflection()->GetMessage(*src, field),
dst->GetReflection()->MutableMessage(dst, field), seen);
}
}
}
}
// Copies a Message, keeping map entries in order.
std::unique_ptr<Message> CloneMessage(const Message* message) {
std::unique_ptr<Message> result(message->New());
result->CopyFrom(*message);
OrderMapEntries(message, result.get());
return result;
}
using MessageMap = std::map<std::string, std::unique_ptr<Message>>;
// For a non-repeated field, move the most recently parsed field value
// into the most recently parsed template expression.
void StowFieldValue(Message* message, TemplateExpression* expression) {
void StowFieldValue(Message* message, TemplateExpression* expression,
MessageMap* stowed_messages) {
const Reflection* reflection = message->GetReflection();
const Descriptor* descriptor = message->GetDescriptor();
ProtoUtilLite::ProtoPath path;
MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path));
int field_number = path[path.size() - 1].first;
int field_number = path[path.size() - 1].field_id;
const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number);
// Save each stowed message unserialized preserving map entry order.
if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) {
(*stowed_messages)[ProtoPathJoin(path)] =
CloneMessage(GetFieldMessage(*message, field, 0));
}
if (!field->is_repeated()) {
std::vector<ProtoUtilLite::FieldValue> field_values;
SerializeField(message, field, &field_values);
@ -1402,6 +1512,112 @@ static void StripQuotes(std::string* str) {
}
}
// Returns the field or extension for field number.
const FieldDescriptor* FindFieldByNumber(const Message* message,
int field_num) {
const FieldDescriptor* result =
message->GetDescriptor()->FindFieldByNumber(field_num);
if (result == nullptr) {
result = message->GetReflection()->FindKnownExtensionByNumber(field_num);
}
return result;
}
// Returns the protobuf map key types from a ProtoPath.
std::vector<FieldType> ProtoPathKeyTypes(ProtoPath path) {
std::vector<FieldType> result;
for (auto& entry : path) {
if (entry.map_id >= 0) {
result.push_back(entry.key_type);
}
}
return result;
}
// Returns the text value for a string or numeric protobuf map key.
std::string GetMapKey(const Message& map_entry) {
auto key_field = map_entry.GetDescriptor()->FindFieldByName("key");
auto reflection = map_entry.GetReflection();
if (key_field->type() == FieldDescriptor::TYPE_STRING) {
return reflection->GetString(map_entry, key_field);
} else if (key_field->type() == FieldDescriptor::TYPE_INT32) {
return absl::StrCat(reflection->GetInt32(map_entry, key_field));
} else if (key_field->type() == FieldDescriptor::TYPE_INT64) {
return absl::StrCat(reflection->GetInt64(map_entry, key_field));
}
return "";
}
// Returns a Message store in CalculatorGraphTemplate::field_value.
Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) {
auto it = stowed_messages->find(ProtoPathJoin(proto_path));
return (it != stowed_messages->end()) ? it->second.get() : nullptr;
}
const Message* GetNestedMessage(const Message& message,
const FieldDescriptor* field,
ProtoPath proto_path,
MessageMap* stowed_messages) {
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
return nullptr;
}
const Message* result = FindStowedMessage(stowed_messages, proto_path);
if (!result) {
result = GetFieldMessage(message, field, proto_path.back().index);
}
return result;
}
// Adjusts map-entries from indexes to keys.
// Protobuf map-entry order is intentionally not preserved.
absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) {
// Copy the rules from the source CalculatorGraphTemplate.
mediapipe::CalculatorGraphTemplate rules;
rules.ParsePartialFromString(source->SerializePartialAsString());
// Only the "source" Message knows all extension types.
Message* config_0 = source->GetReflection()->MutableMessage(
source, source->GetDescriptor()->FindFieldByName("config"), nullptr);
for (int i = 0; i < rules.rule().size(); ++i) {
TemplateExpression* rule = rules.mutable_rule()->Mutable(i);
const Message* message = config_0;
ProtoPath path;
MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path));
for (int j = 0; j < path.size(); ++j) {
int field_id = path[j].field_id;
const FieldDescriptor* field = FindFieldByNumber(message, field_id);
ProtoPath prefix = {path.begin(), path.begin() + j + 1};
message = GetNestedMessage(*message, field, prefix, stowed_messages);
if (!message) {
break;
}
if (field->is_map()) {
const Message* map_entry = message;
int key_id =
map_entry->GetDescriptor()->FindFieldByName("key")->number();
FieldType key_type = static_cast<ProtoUtilLite::FieldType>(
map_entry->GetDescriptor()->FindFieldByName("key")->type());
std::string key_value = GetMapKey(*map_entry);
path[j] = {field_id, key_id, key_type, key_value};
}
}
if (!rule->path().empty()) {
*rule->mutable_path() = ProtoPathJoin(path);
for (FieldType key_type : ProtoPathKeyTypes(path)) {
*rule->mutable_key_type()->Add() = key_type;
}
}
}
// Copy the rules back into the source CalculatorGraphTemplate.
auto source_rules =
source->GetReflection()->GetMutableRepeatedFieldRef<Message>(
source, source->GetDescriptor()->FindFieldByName("rule"));
source_rules.Clear();
for (auto& rule : rules.rule()) {
source_rules.Add(rule);
}
return absl::OkStatus();
}
} // namespace
class TemplateParser::Parser::MediaPipeParserImpl
@ -1416,6 +1632,8 @@ class TemplateParser::Parser::MediaPipeParserImpl
// Copy the template rules into the output template "rule" field.
success &= MergeFields(template_rules_, output).ok();
// Replace map-entry indexes with map keys.
success &= KeyProtoMapEntries(output, &stowed_messages_).ok();
return success;
}
@ -1441,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
DO(ConsumeFieldTemplate(message));
} else {
DO(ConsumeField(message));
StowFieldValue(message, expression);
StowFieldValue(message, expression, &stowed_messages_);
}
DO(ConsumeEndTemplate());
return true;
@ -1652,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl
}
mediapipe::CalculatorGraphTemplate template_rules_;
std::map<std::string, std::unique_ptr<Message>> stowed_messages_;
};
#undef DO

View File

@ -17,10 +17,13 @@ load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph",
)
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
package(default_visibility = [
"//mediapipe:__subpackages__",
])
filegroup(
name = "test_graph",
@ -40,7 +43,6 @@ mediapipe_simple_subgraph(
testonly = 1,
graph = "dub_quad_test_subgraph.pbtxt",
register_as = "DubQuadTestSubgraph",
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:test_calculators",
],
@ -51,9 +53,18 @@ mediapipe_simple_subgraph(
testonly = 1,
graph = "nested_test_subgraph.pbtxt",
register_as = "NestedTestSubgraph",
visibility = ["//visibility:public"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
":dub_quad_test_subgraph",
"//mediapipe/framework:test_calculators",
],
)
mediapipe_proto_library(
name = "frozen_generator_proto",
srcs = ["frozen_generator.proto"],
visibility = ["//mediapipe/framework:__subpackages__"],
deps = [
"//mediapipe/framework:packet_generator_proto",
],
)

View File

@ -0,0 +1,20 @@
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/packet_generator.proto";
message FrozenGeneratorOptions {
extend mediapipe.PacketGeneratorOptions {
optional FrozenGeneratorOptions ext = 225748738;
}
// Path to file containing serialized proto of type tensorflow::GraphDef.
optional string graph_proto_path = 1;
// This map defines the which streams are fed to which tensors in the model.
map<string, string> tag_to_tensor_names = 2;
// Graph nodes to run to initialize the model.
repeated string initialization_op_names = 4;
}

View File

@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize(
input_side_packets_.clear();
output_side_packets_.clear();
stream_to_producer_.clear();
output_streams_to_consumer_nodes_.clear();
input_streams_.clear();
output_streams_.clear();
owned_packet_types_.clear();
@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
<< " does not have a corresponding output stream.";
}
}
// Add this node as a consumer of this edge's output stream.
if (edge_info.upstream > -1) {
auto parent_node = output_streams_[edge_info.upstream].parent_node;
if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) {
int this_idx = node_type_info->Node().index;
output_streams_to_consumer_nodes_[edge_info.upstream].push_back(
this_idx);
}
}
edge_info.parent_node = node_type_info->Node();
edge_info.name = name;

View File

@ -282,6 +282,14 @@ class ValidatedGraphConfig {
return output_streams_[iter->second].parent_node.index;
}
std::vector<int> OutputStreamToConsumers(int idx) const {
auto iter = output_streams_to_consumer_nodes_.find(idx);
if (iter == output_streams_to_consumer_nodes_.end()) {
return {};
}
return iter->second;
}
// Returns the registered type name of the specified side packet if
// it can be determined, otherwise an appropriate error is returned.
absl::StatusOr<std::string> RegisteredSidePacketTypeName(
@ -418,6 +426,10 @@ class ValidatedGraphConfig {
// Mapping from stream name to the output_streams_ index which produces it.
std::map<std::string, int> stream_to_producer_;
// Mapping from output streams to consumer node ids. Used for profiling.
std::map<int, std::vector<int>> output_streams_to_consumer_nodes_;
// Mapping from side packet name to the output_side_packets_ index
// which produces it.
std::map<std::string, int> side_packet_to_producer_;

View File

@ -289,7 +289,9 @@ cc_library(
deps = [
":gpu_buffer_format",
":gpu_buffer_storage",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
":gpu_buffer_storage_image_frame",
@ -472,13 +474,13 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"Accelerate",
"CoreGraphics",
"CoreVideo",
],
visibility = ["//visibility:public"],
deps = ["//mediapipe/objc:util"],
deps = [
"//mediapipe/objc:util",
"//third_party/apple_frameworks:Accelerate",
"//third_party/apple_frameworks:CoreGraphics",
"//third_party/apple_frameworks:CoreVideo",
],
)
objc_library(
@ -510,13 +512,11 @@ objc_library(
"-x objective-c++",
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
"@com_google_absl//absl/time",
"@google_toolbox_for_mac//:GTM_Defines",
],
@ -808,15 +808,13 @@ objc_library(
"-Wno-shorten-64-to-32",
],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
":gpu_shared_data_internal",
":graph_support",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
"@google_toolbox_for_mac//:GTM_Defines",
],
)
@ -1020,16 +1018,14 @@ objc_library(
name = "metal_copy_calculator",
srcs = ["MetalCopyCalculator.mm"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
":simple_shaders_mtl",
"//mediapipe/gpu:copy_calculator_cc_proto",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
],
alwayslink = 1,
)
@ -1038,15 +1034,13 @@ objc_library(
name = "metal_rgb_weight_calculator",
srcs = ["MetalRgbWeightCalculator.mm"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
":simple_shaders_mtl",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
],
alwayslink = 1,
)
@ -1055,15 +1049,13 @@ objc_library(
name = "metal_sobel_calculator",
srcs = ["MetalSobelCalculator.mm"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
":simple_shaders_mtl",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
],
alwayslink = 1,
)
@ -1072,15 +1064,13 @@ objc_library(
name = "metal_sobel_compute_calculator",
srcs = ["MetalSobelComputeCalculator.mm"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
":simple_shaders_mtl",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
],
alwayslink = 1,
)
@ -1090,15 +1080,13 @@ objc_library(
srcs = ["MPSSobelCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
"Metal",
"MetalPerformanceShaders",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
"//third_party/apple_frameworks:MetalPerformanceShaders",
],
alwayslink = 1,
)
@ -1106,15 +1094,13 @@ objc_library(
objc_library(
name = "mps_threshold_calculator",
srcs = ["MPSThresholdCalculator.mm"],
sdk_frameworks = [
"CoreVideo",
"Metal",
"MetalPerformanceShaders",
],
visibility = ["//visibility:public"],
deps = [
":MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Metal",
"//third_party/apple_frameworks:MetalPerformanceShaders",
],
alwayslink = 1,
)

View File

@ -3,6 +3,7 @@
#include <memory>
#include <utility>
#include "absl/functional/bind_front.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "mediapipe/framework/port/logging.h"
@ -25,57 +26,101 @@ struct StorageTypeFormatter {
} // namespace
std::string GpuBuffer::DebugString() const {
return absl::StrCat("GpuBuffer[",
absl::StrJoin(storages_, ", ", StorageTypeFormatter()),
"]");
return holder_ ? absl::StrCat("GpuBuffer[", width(), "x", height(), " ",
format(), " as ", holder_->DebugString(), "]")
: "GpuBuffer[invalid]";
}
internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
std::string GpuBuffer::StorageHolder::DebugString() const {
absl::MutexLock lock(&mutex_);
return absl::StrJoin(storages_, ", ", StorageTypeFormatter());
}
internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView(
TypeId view_provider_type, bool for_writing) const {
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
std::shared_ptr<internal::GpuBufferStorage> chosen_storage;
std::function<std::shared_ptr<internal::GpuBufferStorage>()> conversion;
// First see if any current storage supports the view.
for (const auto& s : storages_) {
if (s->can_down_cast_to(view_provider_type)) {
chosen_storage = &s;
break;
}
}
// Then try to convert existing storages to one that does.
// TODO: choose best conversion.
if (!chosen_storage) {
{
absl::MutexLock lock(&mutex_);
// First see if any current storage supports the view.
for (const auto& s : storages_) {
if (auto converter = internal::GpuBufferStorageRegistry::Get()
.StorageConverterForViewProvider(
view_provider_type, s->storage_type())) {
if (auto new_storage = converter(s)) {
storages_.push_back(new_storage);
chosen_storage = &storages_.back();
if (s->can_down_cast_to(view_provider_type)) {
chosen_storage = s;
break;
}
}
// Then try to convert existing storages to one that does.
// TODO: choose best conversion.
if (!chosen_storage) {
for (const auto& s : storages_) {
if (auto converter = internal::GpuBufferStorageRegistry::Get()
.StorageConverterForViewProvider(
view_provider_type, s->storage_type())) {
conversion = absl::bind_front(converter, s);
break;
}
}
}
}
// Avoid invoking a converter or factory while holding the mutex.
// Two reasons:
// 1. Readers that don't need a conversion will not be blocked.
// 2. We use mutexes to make sure GL contexts are not used simultaneously on
// different threads, and we also rely on Mutex's deadlock detection
// heuristic, which enforces a consistent mutex acquisition order.
// This function is likely to be called within a GL context, and the
// conversion function may in turn use a GL context, and this may cause a
// false positive in the deadlock detector.
// TODO: we could use Mutex::ForgetDeadlockInfo instead.
if (conversion) {
auto new_storage = conversion();
absl::MutexLock lock(&mutex_);
// Another reader might have already completed and inserted the same
// conversion. TODO: prevent this?
for (const auto& s : storages_) {
if (s->can_down_cast_to(view_provider_type)) {
chosen_storage = s;
break;
}
}
if (!chosen_storage) {
storages_.push_back(std::move(new_storage));
chosen_storage = storages_.back();
}
}
if (for_writing) {
// This will temporarily hold storages to be released, and do so while the
// lock is not held (see above).
decltype(storages_) old_storages;
using std::swap;
if (chosen_storage) {
// Discard all other storages.
storages_ = {*chosen_storage};
chosen_storage = &storages_.back();
absl::MutexLock lock(&mutex_);
swap(old_storages, storages_);
storages_ = {chosen_storage};
} else {
// Allocate a new storage supporting the requested view.
if (auto factory =
internal::GpuBufferStorageRegistry::Get()
.StorageFactoryForViewProvider(view_provider_type)) {
if (auto new_storage = factory(width(), height(), format())) {
if (auto new_storage = factory(width_, height_, format_)) {
absl::MutexLock lock(&mutex_);
swap(old_storages, storages_);
storages_ = {std::move(new_storage)};
chosen_storage = &storages_.back();
chosen_storage = storages_.back();
}
}
}
}
return chosen_storage ? chosen_storage->get() : nullptr;
// It is ok to return a non-owning storage pointer here because this object
// ensures the storage's lifetime. Overwriting a GpuBuffer while readers are
// active would violate this, but it's not allowed in MediaPipe.
return chosen_storage ? chosen_storage.get() : nullptr;
}
internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
GpuBuffer::GetStorageForView(view_provider_type, for_writing);
CHECK(chosen_storage) << "no view provider found for requested view "
<< view_provider_type.name() << "; storages available: "
<< absl::StrJoin(storages_, ", ",
StorageTypeFormatter());
<< (holder_ ? holder_->DebugString() : "invalid");
DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
return *chosen_storage;
}

View File

@ -15,9 +15,12 @@
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_
#define MEDIAPIPE_GPU_GPU_BUFFER_H_
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
@ -56,8 +59,7 @@ class GpuBuffer {
// Creates an empty buffer of a given size and format. It will be allocated
// when a view is requested.
GpuBuffer(int width, int height, Format format)
: GpuBuffer(std::make_shared<PlaceholderGpuBufferStorage>(width, height,
format)) {}
: holder_(std::make_shared<StorageHolder>(width, height, format)) {}
// Copy and move constructors and assignment operators are supported.
GpuBuffer(const GpuBuffer& other) = default;
@ -70,9 +72,8 @@ class GpuBuffer {
// are not portable. Applications and calculators should normally obtain
// GpuBuffers in a portable way from the framework, e.g. using
// GpuBufferMultiPool.
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage) {
storages_.push_back(std::move(storage));
}
explicit GpuBuffer(std::shared_ptr<internal::GpuBufferStorage> storage)
: holder_(std::make_shared<StorageHolder>(std::move(storage))) {}
#if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
// This is used to support backward-compatible construction of GpuBuffer from
@ -84,9 +85,11 @@ class GpuBuffer {
: GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {}
#endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
int width() const { return current_storage().width(); }
int height() const { return current_storage().height(); }
GpuBufferFormat format() const { return current_storage().format(); }
int width() const { return holder_ ? holder_->width() : 0; }
int height() const { return holder_ ? holder_->height() : 0; }
GpuBufferFormat format() const {
return holder_ ? holder_->format() : GpuBufferFormat::kUnknown;
}
// Converts to true iff valid.
explicit operator bool() const { return operator!=(nullptr); }
@ -122,31 +125,17 @@ class GpuBuffer {
// using views.
template <class T>
std::shared_ptr<T> internal_storage() const {
for (const auto& s : storages_)
if (s->down_cast<T>()) return std::static_pointer_cast<T>(s);
return nullptr;
return holder_ ? holder_->internal_storage<T>() : nullptr;
}
std::string DebugString() const;
private:
class PlaceholderGpuBufferStorage
: public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {
public:
PlaceholderGpuBufferStorage(int width, int height, Format format)
: width_(width), height_(height), format_(format) {}
int width() const override { return width_; }
int height() const override { return height_; }
GpuBufferFormat format() const override { return format_; }
private:
int width_ = 0;
int height_ = 0;
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
};
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
bool for_writing) const;
bool for_writing) const {
return holder_ ? holder_->GetStorageForView(view_provider_type, for_writing)
: nullptr;
}
internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
bool for_writing) const;
@ -158,25 +147,49 @@ class GpuBuffer {
.template down_cast<VP>();
}
std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {
static auto placeholder =
std::static_pointer_cast<internal::GpuBufferStorage>(
std::make_shared<PlaceholderGpuBufferStorage>(
0, 0, GpuBufferFormat::kUnknown));
return placeholder;
}
// This class manages a set of alternative storages for the contents of a
// GpuBuffer. GpuBuffer was originally designed as a reference-type object,
// where a copy represents another reference to the same contents, so multiple
// GpuBuffer instances can share the same StorageHolder.
class StorageHolder {
public:
explicit StorageHolder(std::shared_ptr<internal::GpuBufferStorage> storage)
: StorageHolder(storage->width(), storage->height(),
storage->format()) {
storages_.push_back(std::move(storage));
}
explicit StorageHolder(int width, int height, Format format)
: width_(width), height_(height), format_(format) {}
const internal::GpuBufferStorage& current_storage() const {
return storages_.empty() ? *no_storage() : *storages_[0];
}
int width() const { return width_; }
int height() const { return height_; }
GpuBufferFormat format() const { return format_; }
internal::GpuBufferStorage& current_storage() {
return storages_.empty() ? *no_storage() : *storages_[0];
}
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
bool for_writing) const;
// This is mutable because view methods that do not change the contents may
// still need to allocate new storages.
mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_;
template <class T>
std::shared_ptr<T> internal_storage() const {
absl::MutexLock lock(&mutex_);
for (const auto& s : storages_)
if (s->down_cast<T>()) return std::static_pointer_cast<T>(s);
return nullptr;
}
std::string DebugString() const;
private:
int width_ = 0;
int height_ = 0;
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
// This is mutable because view methods that do not change the contents may
// still need to allocate new storages.
mutable absl::Mutex mutex_;
mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_
ABSL_GUARDED_BY(mutex_);
};
std::shared_ptr<StorageHolder> holder_;
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer);
@ -184,15 +197,15 @@ class GpuBuffer {
};
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return storages_.empty();
return holder_ == other;
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
return storages_ == other.storages_;
return holder_ == other.holder_;
}
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
storages_.clear();
holder_ = other;
return *this;
}

View File

@ -20,6 +20,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/gpu/gl_texture_util.h"
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) {
EXPECT_TRUE(true);
}
TEST_F(GpuBufferTest, CopiesShareConversions) {
GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32);
{
std::shared_ptr<ImageFrame> view = buffer.GetWriteView<ImageFrame>();
FillImageFrameRGBA(*view, 255, 0, 0, 255);
}
GpuBuffer other_handle = buffer;
RunInGlContext([&buffer] {
TempGlFramebuffer fb;
auto view = buffer.GetReadView<GlTextureView>(0);
});
// Check that other_handle also sees the same GlTextureBuffer as buffer.
// Note that this is deliberately written so that it still passes on platforms
// where we use another storage for GL textures (they will both be null).
// TODO: expose more accessors for testing?
EXPECT_EQ(other_handle.internal_storage<GlTextureBuffer>(),
buffer.internal_storage<GlTextureBuffer>());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -0,0 +1,40 @@
#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_
#import <CoreVideo/CVMetalTextureCache.h>
#import <CoreVideo/CoreVideo.h>
#import <Foundation/NSObject.h>
#import <Metal/Metal.h>
#ifndef __OBJC__
#error This class must be built as Objective-C++.
#endif // !__OBJC__
@interface MPPMetalSharedResources : NSObject {
}
- (instancetype)init NS_DESIGNATED_INITIALIZER;
@property(readonly) id<MTLDevice> mtlDevice;
@property(readonly) id<MTLCommandQueue> mtlCommandQueue;
#if COREVIDEO_SUPPORTS_METAL
@property(readonly) CVMetalTextureCacheRef mtlTextureCache;
#endif
@end
namespace mediapipe {
class MetalSharedResources {
public:
MetalSharedResources();
~MetalSharedResources();
MPPMetalSharedResources* resources() { return resources_; }
private:
MPPMetalSharedResources* resources_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_

View File

@ -0,0 +1,73 @@
#import "mediapipe/gpu/metal_shared_resources.h"
@interface MPPMetalSharedResources ()
@end
@implementation MPPMetalSharedResources {
}
@synthesize mtlDevice = _mtlDevice;
@synthesize mtlCommandQueue = _mtlCommandQueue;
#if COREVIDEO_SUPPORTS_METAL
@synthesize mtlTextureCache = _mtlTextureCache;
#endif
- (instancetype)init {
self = [super init];
if (self) {
}
return self;
}
- (void)dealloc {
#if COREVIDEO_SUPPORTS_METAL
if (_mtlTextureCache) {
CFRelease(_mtlTextureCache);
_mtlTextureCache = NULL;
}
#endif
}
- (id<MTLDevice>)mtlDevice {
@synchronized(self) {
if (!_mtlDevice) {
_mtlDevice = MTLCreateSystemDefaultDevice();
}
}
return _mtlDevice;
}
- (id<MTLCommandQueue>)mtlCommandQueue {
@synchronized(self) {
if (!_mtlCommandQueue) {
_mtlCommandQueue = [self.mtlDevice newCommandQueue];
}
}
return _mtlCommandQueue;
}
#if COREVIDEO_SUPPORTS_METAL
- (CVMetalTextureCacheRef)mtlTextureCache {
@synchronized(self) {
if (!_mtlTextureCache) {
CVReturn __unused err =
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
self.mtlDevice);
// TODO: register and flush metal caches too.
}
}
return _mtlTextureCache;
}
#endif
@end
namespace mediapipe {
MetalSharedResources::MetalSharedResources() {
resources_ = [[MPPMetalSharedResources alloc] init];
}
MetalSharedResources::~MetalSharedResources() {}
} // namespace mediapipe

View File

@ -0,0 +1,49 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
#import <XCTest/XCTest.h>
#include <memory>
#include "absl/memory/memory.h"
#include "mediapipe/framework/port/threadpool.h"
#import "mediapipe/gpu/gpu_shared_data_internal.h"
#import "mediapipe/gpu/metal_shared_resources.h"
@interface MPPMetalSharedResourcesTests : XCTestCase {
}
@end
@implementation MPPMetalSharedResourcesTests
// This test verifies that the internal Objective-C object is correctly
// released when the C++ wrapper is released.
- (void)testCorrectlyReleased {
__weak id metalRes = nil;
std::weak_ptr<mediapipe::GpuResources> weakGpuRes;
@autoreleasepool {
auto maybeGpuRes = mediapipe::GpuResources::Create();
XCTAssertTrue(maybeGpuRes.ok());
weakGpuRes = *maybeGpuRes;
metalRes = (**maybeGpuRes).metal_shared().resources();
XCTAssertNotEqual(weakGpuRes.lock(), nullptr);
XCTAssertNotNil(metalRes);
}
XCTAssertEqual(weakGpuRes.lock(), nullptr);
XCTAssertNil(metalRes);
}
@end

View File

@ -43,6 +43,7 @@ cc_library(
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/image:color_convert_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/image:recolor_calculator",
"//mediapipe/calculators/image:set_alpha_calculator",

View File

@ -60,7 +60,14 @@ node {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:previous_hair_mask"
output_stream: "PREV_LOOP:previous_hair_mask_rgb"
}
# Converts the 4 channel hair mask to a single channel mask
node {
calculator: "ColorConvertCalculator"
input_stream: "RGB_IN:previous_hair_mask_rgb"
output_stream: "GRAY_OUT:previous_hair_mask"
}
# Embeds the hair mask generated from the previous round of hair segmentation

View File

@ -97,7 +97,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_file_properties_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",

View File

@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"])
mediapipe_proto_library(
name = "gl_animation_overlay_calculator_proto",
srcs = ["gl_animation_overlay_calculator.proto"],
def_options_lib = False,
visibility = ["//visibility:public"],
exports = [
"//mediapipe/gpu:gl_animation_overlay_calculator_proto",

View File

@ -84,12 +84,11 @@ cc_library(
deps = [
":class_registry",
":jni_util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",

View File

@ -0,0 +1 @@
recursive-include pip_src/mediapipe_model_maker/models *

View File

@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier

View File

@ -19,7 +19,7 @@ import os
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file.
"""Gets the absolute path of a file in the model_maker directory.
Args:
file_path: The path to a file relative to the `mediapipe` dir
@ -27,10 +27,17 @@ def get_absolute_path(file_path: str) -> str:
Returns:
The full path of the file
"""
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `path` which defines the relative path under mediapipe/, it
# yields to the absolute path of the model files directory.
# Extract the file path before and including 'model_maker' as the
# `mm_base_dir`. By joining it with the `path` after 'model_maker/', it
# yields to the absolute path of the model files directory. We must join
# on 'model_maker' because in the pypi package, the 'model_maker' directory
# is renamed to 'mediapipe_model_maker'. So we have to join on model_maker
# to ensure that the `mm_base_dir` path includes the renamed
# 'mediapipe_model_maker' directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, file_path)
cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker')
mm_base_dir = cwd[:cwd_stop_idx]
file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1
mm_relative_path = file_path[file_path_start_idx:]
absolute_path = os.path.join(mm_base_dir, mm_relative_path)
return absolute_path

View File

@ -4,21 +4,10 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "segmenter_options_proto",
srcs = ["segmenter_options.proto"],
)

View File

@ -35,20 +35,21 @@ py_library(
srcs = ["constants.py"],
)
# TODO: Change to py_library after migrating the MediaPipe hand solution
# library to MediaPipe hand task library.
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
":metadata_writer",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/data:data_util",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/python/solutions:hands",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/vision:hand_landmarker",
],
)
# TODO: Remove notsan tag once tasks no longer has race condition issue
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
@ -56,10 +57,11 @@ py_test(
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
tags = ["notsan"],
deps = [
":dataset",
"//mediapipe/python/solutions:hands",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:hand_landmarker",
],
)
@ -131,6 +133,7 @@ py_library(
],
)
# TODO: Remove notsan tag once tasks no longer has race condition issue
py_test(
name = "gesture_recognizer_test",
size = "large",
@ -140,6 +143,7 @@ py_test(
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
shard_count = 2,
tags = ["notsan"],
deps = [
":gesture_recognizer_import",
"//mediapipe/model_maker/python/core/utils:test_util",

View File

@ -16,16 +16,22 @@
import dataclasses
import os
import random
from typing import List, NamedTuple, Optional
from typing import List, Optional
import cv2
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.core.data import data_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module
_Image = image_module.Image
_HandLandmarker = hand_landmarker_module.HandLandmarker
_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult
@dataclasses.dataclass
@ -59,7 +65,7 @@ class HandData:
handedness: List[float]
def _validate_data_sample(data: NamedTuple) -> bool:
def _validate_data_sample(data: _HandLandmarkerResult) -> bool:
"""Validates the input hand data sample.
Args:
@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool:
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
or any of these attributes' values are none. Otherwise, True.
"""
if (not hasattr(data, 'multi_hand_landmarks') or
data.multi_hand_landmarks is None):
if data.hand_landmarks is None or not data.hand_landmarks:
return False
if (not hasattr(data, 'multi_hand_world_landmarks') or
data.multi_hand_world_landmarks is None):
if data.hand_world_landmarks is None or not data.hand_world_landmarks:
return False
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
if data.handedness is None or not data.handedness:
return False
return True
def _get_hand_data(all_image_paths: List[str],
min_detection_confidence: float) -> Optional[HandData]:
min_detection_confidence: float) -> List[Optional[HandData]]:
"""Computes hand data (landmarks and handedness) in the input image.
Args:
@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str],
A HandData object. Returns None if no hand is detected.
"""
hand_data_result = []
with mp_hands.Hands(
static_image_mode=True,
max_num_hands=1,
min_detection_confidence=min_detection_confidence) as hands:
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
hand_landmarker_options = _HandLandmarkerOptions(
base_options=base_options_module.BaseOptions(
model_asset_buffer=hand_landmarker_writer.populate()),
num_hands=1,
min_hand_detection_confidence=min_detection_confidence,
min_hand_presence_confidence=0.5,
min_tracking_confidence=1,
)
with _HandLandmarker.create_from_options(
hand_landmarker_options) as hand_landmarker:
for path in all_image_paths:
tf.compat.v1.logging.info('Loading image %s', path)
image = data_util.load_image(path)
# Flip image around y-axis for correct handedness output
image = cv2.flip(image, 1)
data = hands.process(image)
image = _Image.create_from_file(path)
data = hand_landmarker.detect(image)
if not _validate_data_sample(data):
hand_data_result.append(None)
continue
hand_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z]
for hand_landmark in data.hand_landmarks[0]]
hand_world_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
] for hand_landmark in data.hand_world_landmarks[0]]
handedness_scores = [
handedness.score
for handedness in data.multi_handedness[0].classification
handedness.score for handedness in data.handedness[0]
]
hand_data_result.append(
HandData(

View File

@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import shutil
from typing import NamedTuple
import unittest
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.tasks.python.test import test_utils
FLAGS = flags.FLAGS
from mediapipe.tasks.python.vision import hand_landmarker
_TEST_DATA_DIRNAME = 'raw_data'
@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 17)
self.assertLen(train_data, 16)
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(train_data.num_classes, 4)
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
self.assertLen(test_data, 18)
self.assertLen(test_data, 16)
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertLen(data, 32)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertLen(data, 32)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
@parameterized.named_parameters(
dict(
testcase_name='invalid_field_name_multi_hand_landmark',
hand=collections.namedtuple('Hand', [
'multi_hand_landmark', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, 3)),
testcase_name='none_handedness',
hand=hand_landmarker.HandLandmarkerResult(
handedness=None, hand_landmarks=[[2]],
hand_world_landmarks=[[3]])),
dict(
testcase_name='invalid_field_name_multi_hand_world_landmarks',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmark',
'multi_handedness'
])(1, 2, 3)),
testcase_name='none_hand_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=None,
hand_world_landmarks=[[3]])),
dict(
testcase_name='invalid_field_name_multi_handed',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handed'
])(1, 2, 3)),
testcase_name='none_hand_world_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[[2]],
hand_world_landmarks=None)),
dict(
testcase_name='multi_hand_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(None, 2, 3)),
testcase_name='empty_handedness',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])),
dict(
testcase_name='multi_hand_world_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, None, 3)),
testcase_name='empty_hand_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])),
dict(
testcase_name='multi_handedness_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, None)),
testcase_name='empty_hand_world_landmarks',
hand=hand_landmarker.HandLandmarkerResult(
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
)
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object(
mp_hands.Hands, 'process', return_value=hand):
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder(

View File

@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
return f.read()
class HandLandmarkerMetadataWriter:
"""MetadataWriter to write the model asset bundle for HandLandmarker."""
def __init__(
self,
hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
) -> None:
"""Initializes HandLandmarkerMetadataWriter to write model asset bundle.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
def populate(self):
"""Creates the model asset bundle for hand landmarker task.
Returns:
Model asset bundle in bytes
"""
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
return hand_landmarker_model_buffer
class MetadataWriter:
"""MetadataWriter to write the metadata and the model asset bundle."""
@ -86,8 +130,8 @@ class MetadataWriter:
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
gesture classifier metadata into the TFLite file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
@ -147,16 +191,8 @@ class MetadataWriter:
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
"""
# Creates the model asset bundle for hand landmarker task.
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate(
)
# Write metadata into custom gesture classifier model.
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
@ -179,7 +215,7 @@ class MetadataWriter:
# graph.
gesture_recognizer_models = {
_HAND_LANDMARKER_BUNDLE_NAME:
read_file(output_hand_landmarker_path),
hand_landmarker_model_buffer,
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
read_file(output_hand_gesture_recognizer_path),
}

View File

@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
class MetadataWriterTest(tf.test.TestCase):
def test_hand_landmarker_metadata_writer(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"
hand_landmarks_detector_model_buffer = b"\x22"
writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
model_bundle_content = writer.populate()
model_bundle_filepath = os.path.join(self.get_temp_dir(),
"hand_landmarker.task")
with open(model_bundle_filepath, "wb") as f:
f.write(model_bundle_content)
with zipfile.ZipFile(model_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
def test_write_metadata_and_create_model_asset_bundle_successful(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"

View File

@ -1,6 +1,8 @@
absl-py
mediapipe==0.9.1
numpy
opencv-contrib-python
tensorflow
opencv-python
tensorflow>=2.10
tensorflow-datasets
tensorflow-hub
tf-models-official>=2.10.1

View File

@ -0,0 +1,147 @@
"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Setup for Mediapipe-Model-Maker package with setuptools.
"""
import glob
import os
import shutil
import subprocess
import sys
import setuptools
__version__ = 'dev'
MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
# Build dir to copy all necessary files and build package
SRC_NAME = 'pip_src'
BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME)
BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker')
def _parse_requirements(path):
with open(os.path.join(MM_ROOT_PATH, path)) as f:
return [
line.rstrip()
for line in f
if not (line.isspace() or line.startswith('#'))
]
def _copy_to_pip_src_dir(file):
"""Copy a file from bazel-bin to the pip_src dir."""
dst = file
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file)
shutil.copyfile(src_file, file)
def _setup_build_dir():
"""Setup the BUILD_DIR directory to build the mediapipe_model_maker package.
We need to create a new BUILD_DIR directory because any references to the path
`mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to
avoid conflicting with the mediapipe package name.
This setup function performs the following actions:
1. Copy python source code into BUILD_DIR and rename imports to
mediapipe_model_maker
2. Download models from GCS into BUILD_DIR
"""
# Copy python source code into BUILD_DIR
if os.path.exists(BUILD_DIR):
shutil.rmtree(BUILD_DIR)
python_files = glob.glob('python/**/*.py', recursive=True)
python_files.append('__init__.py')
for python_file in python_files:
# Exclude test files from pip package
if '_test.py' in python_file:
continue
build_target_file = os.path.join(BUILD_MM_DIR, python_file)
with open(python_file, 'r') as file:
filedata = file.read()
# Rename all mediapipe.model_maker imports to mediapipe_model_maker
filedata = filedata.replace('from mediapipe.model_maker',
'from mediapipe_model_maker')
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
with open(build_target_file, 'w') as file:
file.write(filedata)
# Use bazel to download GCS model files
model_build_files = ['models/gesture_recognizer/BUILD']
for model_build_file in model_build_files:
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
shutil.copy(model_build_file, build_target_file)
external_files = [
'models/gesture_recognizer/canned_gesture_classifier.tflite',
'models/gesture_recognizer/gesture_embedder.tflite',
'models/gesture_recognizer/hand_landmark_full.tflite',
'models/gesture_recognizer/palm_detection_full.tflite',
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
]
for elem in external_files:
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
sys.stderr.write('downloading file: %s\n' % external_file)
fetch_model_command = [
'bazel',
'build',
external_file,
]
if subprocess.call(fetch_model_command) != 0:
sys.exit(-1)
_copy_to_pip_src_dir(external_file)
_setup_build_dir()
setuptools.setup(
name='mediapipe-model-maker',
version=__version__,
url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker',
description='MediaPipe Model Maker is a simple, low-code solution for customizing on-device ML models',
author='The MediaPipe Authors',
author_email='mediapipe@google.com',
long_description='',
long_description_content_type='text/markdown',
packages=setuptools.find_packages(where=SRC_NAME),
package_dir={'': SRC_NAME},
install_requires=_parse_requirements('requirements.txt'),
include_package_data=True,
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Operating System :: MacOS :: MacOS X',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3 :: Only',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords=['mediapipe', 'model', 'maker'],
)

View File

@ -24,7 +24,6 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",

View File

@ -22,6 +22,8 @@
namespace mediapipe {
using ::mediapipe::NormalizedRect;
namespace {
// NORM_LANDMARKS is either the full set of landmarks for the hand, or

View File

@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kTrackingRectTag[] = "TRACKING_RECT";
using ::mediapipe::NormalizedRect;
// TODO: Use rect rotation.
// Verifies that Intersection over Union of previous frame rect and current
// frame re-crop rect is less than threshold.

View File

@ -275,7 +275,6 @@ cc_library(
":tflite_tensors_to_objects_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/memory",
@ -299,7 +298,6 @@ cc_library(
":tensors_to_objects_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/memory",
@ -316,13 +314,11 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":annotation_cc_proto",
":belief_decoder_config_cc_proto",
":decoder",
":lift_2d_frame_annotation_to_3d_calculator_cc_proto",
":tensor_util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/memory",

View File

@ -34,6 +34,8 @@ namespace {
constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION";
constexpr char kOutputNormRectsTag[] = "NORM_RECTS";
using ::mediapipe::NormalizedRect;
} // namespace
// A calculator that converts FrameAnnotation proto to NormalizedRect.

View File

@ -68,7 +68,6 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = ["Accelerate"],
# This build rule is public to allow external customers to build their own iOS apps.
visibility = ["//visibility:public"],
deps = [
@ -90,6 +89,7 @@ objc_library(
"//mediapipe/gpu:metal_shared_resources",
"//mediapipe/gpu:pixel_buffer_pool_util",
"//mediapipe/util:cpu_util",
"//third_party/apple_frameworks:Accelerate",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -120,13 +120,13 @@ objc_library(
],
"//conditions:default": [],
}),
sdk_frameworks = [
"AVFoundation",
"CoreVideo",
"Foundation",
],
# This build rule is public to allow external customers to build their own iOS apps.
visibility = ["//visibility:public"],
deps = [
"//third_party/apple_frameworks:AVFoundation",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Foundation",
],
)
objc_library(
@ -140,16 +140,14 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"Foundation",
"GLKit",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":mediapipe_framework_ios",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:gl_simple_shaders",
"//third_party/apple_frameworks:Foundation",
"//third_party/apple_frameworks:GLKit",
],
)
@ -164,16 +162,14 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"Foundation",
"GLKit",
],
# This build rule is public to allow external customers to build their own iOS apps.
visibility = ["//visibility:public"],
deps = [
":mediapipe_framework_ios",
":mediapipe_gl_view_renderer",
"//mediapipe/gpu:gl_calculator_helper",
"//third_party/apple_frameworks:Foundation",
"//third_party/apple_frameworks:GLKit",
],
)
@ -188,13 +184,11 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"CoreVideo",
"Foundation",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":mediapipe_framework_ios",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:Foundation",
"@com_google_absl//absl/strings",
],
)
@ -211,23 +205,21 @@ objc_library(
copts = [
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"AVFoundation",
"Accelerate",
"CoreGraphics",
"CoreMedia",
"CoreVideo",
"GLKit",
"OpenGLES",
"QuartzCore",
"UIKit",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":CGImageRefUtils",
":Weakify",
":mediapipe_framework_ios",
"//mediapipe/framework:calculator_framework",
"//third_party/apple_frameworks:AVFoundation",
"//third_party/apple_frameworks:Accelerate",
"//third_party/apple_frameworks:CoreGraphics",
"//third_party/apple_frameworks:CoreMedia",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:GLKit",
"//third_party/apple_frameworks:OpenGLES",
"//third_party/apple_frameworks:QuartzCore",
"//third_party/apple_frameworks:UIKit",
],
)
@ -245,16 +237,6 @@ objc_library(
data = [
"testdata/googlelogo_color_272x92dp.png",
],
sdk_frameworks = [
"AVFoundation",
"Accelerate",
"CoreGraphics",
"CoreMedia",
"CoreVideo",
"GLKit",
"QuartzCore",
"UIKit",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":CGImageRefUtils",
@ -263,6 +245,14 @@ objc_library(
":mediapipe_framework_ios",
":mediapipe_input_sources_ios",
"//mediapipe/calculators/core:pass_through_calculator",
"//third_party/apple_frameworks:AVFoundation",
"//third_party/apple_frameworks:Accelerate",
"//third_party/apple_frameworks:CoreGraphics",
"//third_party/apple_frameworks:CoreMedia",
"//third_party/apple_frameworks:CoreVideo",
"//third_party/apple_frameworks:GLKit",
"//third_party/apple_frameworks:QuartzCore",
"//third_party/apple_frameworks:UIKit",
],
)

View File

@ -22,9 +22,7 @@ cc_library(
name = "audio_classifier",
srcs = ["audio_classifier.cc"],
hdrs = ["audio_classifier.h"],
visibility = [
"//mediapipe/tasks:users",
],
visibility = ["//visibility:public"],
deps = [
":audio_classifier_graph",
"//mediapipe/framework/api2:builder",

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -22,9 +22,7 @@ cc_library(
name = "audio_embedder",
srcs = ["audio_embedder.cc"],
hdrs = ["audio_embedder.h"],
visibility = [
"//mediapipe/tasks:users",
],
visibility = ["//visibility:public"],
deps = [
":audio_embedder_graph",
"//mediapipe/framework/api2:builder",

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ licenses(["notice"])
cc_library(
name = "rect",
srcs = ["rect.cc"],
hdrs = ["rect.h"],
)
@ -41,6 +42,18 @@ cc_library(
],
)
cc_library(
name = "detection_result",
srcs = ["detection_result.cc"],
hdrs = ["detection_result.h"],
deps = [
":category",
":rect",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
],
)
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],

View File

@ -0,0 +1,73 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include <strings.h>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
constexpr int kDefaultCategoryIndex = -1;
Detection ConvertToDetectionResult(
const mediapipe::Detection& detection_proto) {
Detection detection;
for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
detection.categories.push_back(
{/* index= */ detection_proto.label_id_size() > idx
? detection_proto.label_id(idx)
: kDefaultCategoryIndex,
/* score= */ detection_proto.score(idx),
/* category_name */ detection_proto.label_size() > idx
? detection_proto.label(idx)
: "",
/* display_name */ detection_proto.display_name_size() > idx
? detection_proto.display_name(idx)
: ""});
}
Rect bounding_box;
if (detection_proto.location_data().has_bounding_box()) {
mediapipe::LocationData::BoundingBox bounding_box_proto =
detection_proto.location_data().bounding_box();
bounding_box.left = bounding_box_proto.xmin();
bounding_box.top = bounding_box_proto.ymin();
bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
bounding_box.bottom =
bounding_box_proto.ymin() + bounding_box_proto.height();
}
detection.bounding_box = bounding_box;
return detection;
}
DetectionResult ConvertToDetectionResult(
std::vector<mediapipe::Detection> detections_proto) {
DetectionResult detection_result;
detection_result.detections.reserve(detections_proto.size());
for (const auto& detection_proto : detections_proto) {
detection_result.detections.push_back(
ConvertToDetectionResult(detection_proto));
}
return detection_result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Detection for a single bounding box.
struct Detection {
// A vector of detected categories.
std::vector<Category> categories;
// The bounding box location.
Rect bounding_box;
};
// Detection results of a model.
struct DetectionResult {
// A vector of Detections.
std::vector<Detection> detections;
};
// Utility function to convert from Detection proto to Detection struct.
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
// Utility function to convert from list of Detection proto to DetectionResult
// struct.
DetectionResult ConvertToDetectionResult(
std::vector<mediapipe::Detection> detections_proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_

View File

@ -0,0 +1,34 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
return RectF{static_cast<float>(rect.left) / image_width,
static_cast<float>(rect.top) / image_height,
static_cast<float>(rect.right) / image_width,
static_cast<float>(rect.bottom) / image_height};
}
Rect ToRect(const RectF& rect, int image_height, int image_width) {
return Rect{static_cast<int>(rect.left * image_width),
static_cast<int>(rect.top * image_height),
static_cast<int>(rect.right * image_width),
static_cast<int>(rect.bottom * image_height)};
}
} // namespace mediapipe::tasks::components::containers

Some files were not shown because too many files have changed in this diff Show More