Project import generated by Copybara.

GitOrigin-RevId: d1277f8cb42aa228165e96775687ff2e0effffcf
This commit is contained in:
MediaPipe Team 2020-04-24 14:06:05 -07:00 committed by jqtang
parent 7bad8fce62
commit b6e680647c
19 changed files with 289 additions and 126 deletions

View File

@ -69,3 +69,7 @@ build:ios_arm64e --watchos_cpus=armv7k
build:ios_fat --config=ios
build:ios_fat --ios_multi_cpus=armv7,arm64
build:ios_fat --watchos_cpus=armv7k
build:darwin_x86_64 --apple_platform_type=macos
build:darwin_x86_64 --macos_minimum_os=10.12
build:darwin_x86_64 --cpu=darwin_x86_64

View File

@ -407,7 +407,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK(input_tensors.empty());
RET_CHECK(!input_tensors.empty());
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
for (int i = 0; i < input_tensors.size(); ++i) {

View File

@ -118,7 +118,7 @@ project.
implementation 'com.google.code.findbugs:jsr305:3.0.2'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-lite:3.0.0'
implementation 'com.google.protobuf:protobuf-java:3.11.4''
// CameraX core library
def camerax_version = "1.0.0-alpha06"
implementation "androidx.camera:camera-core:$camerax_version"

View File

@ -14,10 +14,6 @@ We show the face detection demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Face Detection Demo with Webcam (GPU)](#tensorflow-lite-face-detection-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the
[known issues with OpenCV 2](./object_detection_desktop.md#known-issues-with-opencv-2)
section.
@ -46,11 +42,14 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_de
### TensorFlow Lite Face Detection Demo with Webcam (GPU)
Note: This currently works only on Linux, and please first follow
[OpenGL ES Setup on Linux Desktop](./gpu.md#opengl-es-setup-on-linux-desktop).
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
# This works only for Linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \
mediapipe/examples/desktop/face_detection:face_detection_gpu
@ -68,9 +67,6 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_de
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
```
Issues running? Please first
[check that your GPU is supported](./gpu.md#desktop-gpu-linux)
#### Graph
![graph visualization](images/face_detection_desktop.png)

View File

@ -12,10 +12,6 @@ please see [Face Mesh on Android/iOS](face_mesh_mobile_gpu.md).
- [Face Mesh on Desktop with Webcam (GPU)](#face-mesh-on-desktop-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
### Face Mesh on Desktop with Webcam (CPU)
@ -38,12 +34,13 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_mesh/face_mesh_cp
### Face Mesh on Desktop with Webcam (GPU)
Note: please first [check that your GPU is supported](gpu.md#desktop-gpu-linux).
Note: This currently works only on Linux, and please first follow
[OpenGL ES Setup on Linux Desktop](./gpu.md#opengl-es-setup-on-linux-desktop).
To build and run Face Mesh on desktop with webcam (GPU), run:
```bash
# This works only for linux currently
# This works only for Linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \
mediapipe/examples/desktop/face_mesh:face_mesh_gpu

View File

@ -1,13 +1,15 @@
## Running on GPUs
- [Overview](#overview)
- [OpenGL Support](#opengl-support)
- [Desktop GPUs](#desktop-gpu-linux)
- [Life of a GPU calculator](#life-of-a-gpu-calculator)
- [GpuBuffer to ImageFrame converters](#gpubuffer-to-imageframe-converters)
- [Disable GPU support](#disable-gpu-support)
- [OpenGL ES Support](#opengl-es-support)
- [Disable OpenGL ES Support](#disable-opengl-es-support)
- [OpenGL ES Setup on Linux Desktop](#opengl-es-setup-on-linux-desktop)
- [TensorFlow CUDA Support and Setup on Linux Desktop](#tensorflow-cuda-support-and-setup-on-linux-desktop)
- [Life of a GPU Calculator](#life-of-a-gpu-calculator)
- [GpuBuffer to ImageFrame Converters](#gpubuffer-to-imageframe-converters)
### Overview
MediaPipe supports calculator nodes for GPU compute and rendering, and allows combining multiple GPU nodes, as well as mixing them with CPU based calculator nodes. There exist several GPU APIs on mobile platforms (eg, OpenGL ES, Metal and Vulkan). MediaPipe does not attempt to offer a single cross-API GPU abstraction. Individual nodes can be written using different APIs, allowing them to take advantage of platform specific features when needed.
GPU support is essential for good performance on mobile platforms, especially for real-time video. MediaPipe enables developers to write GPU compatible calculators that support the use of GPU for:
@ -23,7 +25,7 @@ Below are the design principles for GPU support in MediaPipe
* Because different platforms may require different techniques for best performance, the API should allow flexibility in the way things are implemented behind the scenes.
* A calculator should be allowed maximum flexibility in using the GPU for all or part of its operation, combining it with the CPU if necessary.
### OpenGL Support
### OpenGL ES Support
MediaPipe supports OpenGL ES up to version 3.2 on Android/Linux and up to ES 3.0
on iOS. In addition, MediaPipe also supports Metal on iOS.
@ -48,12 +50,28 @@ some Android devices. Therefore, our approach is to have one dedicated thread
per context. Each thread issues GL commands, building up a serial command queue
on its context, which is then executed by the GPU asynchronously.
#### Desktop GPU (Linux)
### Disable OpenGL ES Support
MediaPipe GPU can run on linux systems with video cards that support OpenGL ES
3.1 and up.
By default, building MediaPipe (with no special bazel flags) attempts to compile
and link against OpenGL ES (and for iOS also Metal) libraries.
To check if your linux desktop GPU can run mediapipe:
On platforms where OpenGL ES is not available (see also
[OpenGL ES Setup on Linux Desktop](#opengl-es-setup-on-linux-desktop)), you
should disable OpenGL ES support with:
```
$ bazel build --define MEDIAPIPE_DISABLE_GPU=1 <my-target>
```
Note: On Android and iOS, OpenGL ES is required by MediaPipe framework and the
support should never be disabled.
### OpenGL ES Setup on Linux Desktop
On Linux desktop with video cards that support OpenGL ES 3.1+, MediaPipe can run
GPU compute and rendering and perform TFLite inference on GPU.
To check if your Linux desktop GPU can run MediaPipe with OpenGL ES:
```bash
$ sudo apt-get install mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev
@ -61,7 +79,7 @@ $ sudo apt-get install mesa-utils
$ glxinfo | grep -i opengl
```
My linux box prints:
For example, it may print:
```bash
$ glxinfo | grep -i opengl
@ -71,14 +89,133 @@ OpenGL ES profile shading language version string: OpenGL ES GLSL ES 3.20
OpenGL ES profile extensions:
```
*^notice the OpenGL ES 3.2 text^*
*Notice the ES 3.20 text above.*
To run MediaPipe GPU on desktop, you need to see ES 3.1 or greater printed.
You need to see ES 3.1 or greater printed in order to perform TFLite inference
on GPU in MediaPipe. With this setup, build with:
If OpenGL ES is not printed, or is below 3.1, then the GPU inference will not
run.
```
$ bazel build --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 <my-target>
```
### Life of a GPU calculator
If only ES 3.0 or below is supported, you can still build MediaPipe targets that
don't require TFLite inference on GPU with:
```
$ bazel build --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 --copt -DMEDIAPIPE_DISABLE_GL_COMPUTE <my-target>
```
Note: MEDIAPIPE_DISABLE_GL_COMPUTE is already defined automatically on all Apple
systems (Apple doesn't support OpenGL ES 3.1+).
### TensorFlow CUDA Support and Setup on Linux Desktop
MediaPipe framework doesn't require CUDA for GPU compute and rendering. However,
MediaPipe can work with TensorFlow to perform GPU inference on video cards that
support CUDA.
To enable TensorFlow GPU inference with MediaPipe, the first step is to follow
the
[TensorFlow GPU documentation](https://www.tensorflow.org/install/gpu#software_requirements)
to install the required NVIDIA software on your Linux desktop.
After installation, update `$PATH` and `$LD_LIBRARY_PATH` and run `ldconfig`
with:
```
$ export PATH=/usr/local/cuda-10.1/bin${PATH:+:${PATH}}
$ export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64,/usr/local/cuda-10.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
$ sudo ldconfig
```
It's recommended to verify the installation of CUPTI, CUDA, CuDNN, and NVCC:
```
$ ls /usr/local/cuda/extras/CUPTI
/lib64
libcupti.so libcupti.so.10.1.208 libnvperf_host.so libnvperf_target.so
libcupti.so.10.1 libcupti_static.a libnvperf_host_static.a
$ ls /usr/local/cuda-10.1
LICENSE bin extras lib64 libnvvp nvml samples src tools
README doc include libnsight nsightee_plugins nvvm share targets version.txt
$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
$ ls /usr/lib/x86_64-linux-gnu/ | grep libcudnn.so
libcudnn.so
libcudnn.so.7
libcudnn.so.7.6.4
```
Setting `$TF_CUDA_PATHS` is the way to declare where the CUDA library is. Note
that the following code snippet also adds `/usr/lib/x86_64-linux-gnu` and
`/usr/include` into `$TF_CUDA_PATHS` for cudablas and libcudnn.
```
$ export TF_CUDA_PATHS=/usr/local/cuda-10.1,/usr/lib/x86_64-linux-gnu,/usr/include
```
To make MediaPipe get TensorFlow's CUDA settings, find TensorFlow's
[.bazelrc](https://github.com/tensorflow/tensorflow/blob/master/.bazelrc) and
copy the `build:using_cuda` and `build:cuda` section into MediaPipe's .bazelrc
file. For example, as of April 23, 2020, TensorFlow's CUDA setting is the
following:
```
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
```
Finally, build MediaPipe with TensorFlow GPU with two more flags `--config=cuda`
and `--spawn_strategy=local`. For example:
```
$ bazel build -c opt --config=cuda --spawn_strategy=local \
--define no_aws_support=true --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/object_detection:object_detection_tensorflow
```
While the binary is running, it prints out the GPU device info:
```
I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
I external/org_tensorflow/tensorflow/core/common_runtime/gpu/gpu_device.cc:1544] Found device 0 with properties: pciBusID: 0000:00:04.0 name: Tesla T4 computeCapability: 7.5 coreClock: 1.59GHz coreCount: 40 deviceMemorySize: 14.75GiB deviceMemoryBandwidth: 298.08GiB/s
I external/org_tensorflow/tensorflow/core/common_runtime/gpu/gpu_device.cc:1686] Adding visible gpu devices: 0
```
You can monitor the GPU usage to verify whether the GPU is used for model
inference.
```
$ nvidia-smi --query-gpu=utilization.gpu --format=csv --loop=1
0 %
0 %
4 %
5 %
83 %
21 %
22 %
27 %
29 %
100 %
0 %
0%
```
### Life of a GPU Calculator
This section presents the basic structure of the Process method of a GPU
calculator derived from base class GlSimpleCalculator. The GPU calculator
@ -165,7 +302,7 @@ choices for MediaPipe GPU support:
* Data that needs to be shared between all GPU-based calculators is provided as a external input that is implemented as a graph service and is managed by the `GlCalculatorHelper` class.
* The combination of calculator-specific helpers and a shared graph service allows us great flexibility in managing the GPU resource: we can have a separate context per calculator, share a single context, share a lock or other synchronization primitives, etc. -- and all of this is managed by the helper and hidden from the individual calculators.
### GpuBuffer to ImageFrame converters
### GpuBuffer to ImageFrame Converters
We provide two calculators called `GpuBufferToImageFrameCalculator` and `ImageFrameToGpuBufferCalculator`. These calculators convert between `ImageFrame` and `GpuBuffer`, allowing the construction of graphs that combine GPU and CPU calculators. They are supported on both iOS and Android
@ -176,27 +313,3 @@ The below diagram shows the data flow in a mobile application that captures vide
| ![How GPU calculators interact](images/gpu_example_graph.png) |
|:--:|
| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. The input stream is accessed by two calculators in parallel. `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, which is then sent through a grayscale converter and a canny filter (both based on OpenCV and running on the CPU), whose output is then converted into a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as input both the original `GpuBuffer` and the one coming out of the edge detector, and overlays them using a shader. The output is then sent back to the application using a callback calculator, and the application renders the image to the screen using OpenGL.* |
### Disable GPU Support
By default, building MediaPipe (with no special bazel flags) attempts to compile
and link against OpenGL/Metal libraries.
There are some command line build flags available to disable/enable GPU support
within the MediaPipe framework:
```
# To disable *all* gpu support
bazel build --define MEDIAPIPE_DISABLE_GPU=1 <my-target>
# to enable full GPU support (OpenGL ES 3.1+ & Metal)
bazel build --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 <my-target>
# to enable only OpenGL ES 3.0 and below (no GLES 3.1+ features)
bazel build --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 --copt -DMEDIAPIPE_DISABLE_GL_COMPUTE <my-target>
```
Note *MEDIAPIPE_DISABLE_GL_COMPUTE* is automatically defined on all Apple
systems (Apple doesn't support OpenGL ES 3.1+).
Note on iOS and Android, it is assumed that GPU support will be enabled.

View File

@ -11,21 +11,20 @@ We show the hair segmentation demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Hair Segmentation Demo with Webcam (GPU)](#tensorflow-lite-hair-segmentation-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the
[known issues with OpenCV 2](./object_detection_desktop.md#known-issues-with-opencv-2)
section.
### TensorFlow Lite Hair Segmentation Demo with Webcam (GPU)
Note: This currently works only on Linux, and please first follow
[OpenGL ES Setup on Linux Desktop](./gpu.md#opengl-es-setup-on-linux-desktop).
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
# This works only for Linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \
mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu
@ -42,9 +41,6 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair
--calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
```
Issues running? Please first
[check that your GPU is supported](./gpu.md#desktop-gpu-linux)
#### Graph
![hair_segmentation_mobile_gpu_graph](images/mobile/hair_segmentation_mobile_gpu.png)

View File

@ -13,10 +13,6 @@ We show the hand tracking demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-hand-tracking-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the
[known issues with OpenCV 2](./object_detection_desktop.md#known-issues-with-opencv-2)
section.
@ -43,11 +39,14 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tra
### TensorFlow Lite Hand Tracking Demo with Webcam (GPU)
Note: This currently works only on Linux, and please first follow
[OpenGL ES Setup on Linux Desktop](./gpu.md#opengl-es-setup-on-linux-desktop).
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
# This works only for Linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \
mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu
@ -63,9 +62,6 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tra
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
```
Issues running? Please first
[check that your GPU is supported](./gpu.md#desktop-gpu-linux)
#### Graph
![graph visualization](images/hand_tracking_desktop.png)

View File

@ -16,7 +16,7 @@ type [`ImageFrame`] and [`GpuBuffer`]. [`ImageFrame`] refers to image data in
CPU memory in any of a number of bitmap image formats. [`GpuBuffer`] refers to
image data in GPU memory. You can find more detail in the Framework Concepts
section
[GpuBuffer to ImageFrame converters](./gpu.md).
[GpuBuffer to ImageFrame Converters](./gpu.md#gpubuffer-to-imageframe-converters).
You can see an example in:
* [`object_detection_mobile_cpu.pbtxt`]

View File

@ -654,7 +654,8 @@ and install a MediaPipe example app.
7. Select `Import Bazel Project`.
* Select `Workspace`: `/path/to/mediapipe` and select `Next`.
* Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select `Next`.
* Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` and select
`Next`.
* Modify `Project View` to be the following and select `Finish`.
```
@ -669,15 +670,17 @@ and install a MediaPipe example app.
//mediapipe/java/...:all
android_sdk_platform: android-29
sync_flags:
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain
```
8. Select `Bazel` | `Sync` | `Sync project with Build files`.
Note: Even after doing step 4, if you still see the error:
`"no such package '@androidsdk//': Either the path
attribute of android_sdk_repository or the ANDROID_HOME environment variable
must be set."`, please modify the **WORKSPACE** file to point
to your SDK and NDK library locations, as below:
Note: Even after doing step 4, if you still see the error: `"no such package
'@androidsdk//': Either the path attribute of android_sdk_repository or the
ANDROID_HOME environment variable must be set."`, please modify the
**WORKSPACE** file to point to your SDK and NDK library locations, as below:
```
android_sdk_repository(

View File

@ -14,10 +14,6 @@ We show the hand tracking demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-multi-hand-tracking-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the
[known issues with OpenCV 2](./object_detection_desktop.md#known-issues-with-opencv-2)
section.
@ -43,11 +39,14 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/mu
### TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)
Note: This currently works only on Linux, and please first follow
[OpenGL ES Setup on Linux Desktop](./gpu.md#opengl-es-setup-on-linux-desktop).
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
# This works only for Linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 \
mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu
@ -62,9 +61,6 @@ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/mu
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt
```
Issues running? Please first
[check that your GPU is supported](./gpu.md#desktop-gpu-linux)
#### Graph
![graph visualization](images/multi_hand_tracking_desktop.png)

View File

@ -18,7 +18,12 @@ Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV
### TensorFlow Object Detection Demo
To build and run the TensorFlow example on desktop, run:
Note: If you would like to run TensorFlow inference on GPU on Linux, please
follow
[TensorFlow CUDA Support and Setup on Linux Desktop](gpu.md#tensorflow-cuda-support-and-setup-on-linux-desktop)
instead.
To build and run the TensorFlow inference example on CPU on desktop, run:
```bash
# Note that this command also builds TensorFlow targets from scratch, it may

View File

@ -600,3 +600,8 @@ cc_test(
"@com_google_absl//absl/strings",
],
)
exports_files(
["build_defs.bzl"],
visibility = ["//mediapipe/framework:__subpackages__"],
)

View File

@ -0,0 +1,11 @@
"""MediaPipe BUILD rules and related utilities."""
# Sanitize a dependency so that it works correctly from targets that
# include MediaPipe as an external dependency.
def clean_dep(dep):
return str(Label(dep))
# Sanitize a list of dependencies so that they work correctly from targets that
# include MediaPipe as an external dependency.
def clean_deps(dep_list):
return [clean_dep(dep) for dep in dep_list]

View File

@ -18,6 +18,7 @@ Example:
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto", "generate_proto_descriptor_set")
load("//mediapipe/framework:transitive_protos.bzl", "transitive_protos")
load("//mediapipe/framework/deps:expand_template.bzl", "expand_template")
load("//mediapipe/framework/tool:build_defs.bzl", "clean_dep")
def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], testonly = False, **kwargs):
"""Converts a graph from text format to binary format."""
@ -39,7 +40,7 @@ def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], te
name = name + "_text_to_binary_graph",
visibility = ["//visibility:private"],
deps = [
"//mediapipe/framework/tool:text_to_binary_graph",
clean_dep("//mediapipe/framework/tool:text_to_binary_graph"),
name + "_gather_cc_protos",
],
tags = ["manual"],
@ -81,12 +82,13 @@ def data_as_c_string(
fail("srcs must be a single-element list")
if outs == None:
outs = [name]
encode_as_c_string = clean_dep("//mediapipe/framework/tool:encode_as_c_string")
native.genrule(
name = name,
srcs = srcs,
outs = outs,
cmd = "$(location //mediapipe/framework/tool:encode_as_c_string) \"$<\" > \"$@\"",
tools = ["//mediapipe/framework/tool:encode_as_c_string"],
cmd = "$(location %s) \"$<\" > \"$@\"" % encode_as_c_string,
tools = [encode_as_c_string],
testonly = testonly,
)
@ -127,7 +129,7 @@ def mediapipe_simple_subgraph(
# cc_library for a linked mediapipe graph.
expand_template(
name = name + "_linked_cc",
template = "//mediapipe/framework/tool:simple_subgraph_template.cc",
template = clean_dep("//mediapipe/framework/tool:simple_subgraph_template.cc"),
out = name + "_linked.cc",
substitutions = {
"{{SUBGRAPH_CLASS_NAME}}": register_as,
@ -142,8 +144,8 @@ def mediapipe_simple_subgraph(
graph_base_name + ".inc",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:subgraph",
clean_dep("//mediapipe/framework:calculator_framework"),
clean_dep("//mediapipe/framework:subgraph"),
] + deps,
alwayslink = 1,
visibility = visibility,

View File

@ -14,6 +14,7 @@
#include "mediapipe/framework/tool/name_util.h"
#include <set>
#include <unordered_map>
#include "absl/strings/str_cat.h"

View File

@ -14,8 +14,11 @@
package com.google.mediapipe.framework;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import android.graphics.Bitmap;
import com.google.common.flogger.FluentLogger;
import android.graphics.Bitmap.Config;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@ -27,42 +30,90 @@ import java.nio.ByteOrder;
* <p>This class contains methods that are Android-specific.
*/
public final class AndroidPacketGetter {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/** Gets an {@code ARGB_8888} bitmap from an RGB mediapipe image frame packet. */
/**
* Gets an {@code ARGB_8888} bitmap from an RGB mediapipe image frame packet.
*
* @param packet mediapipe packet
* @return {@link Bitmap} with pixels copied from the packet
*/
public static Bitmap getBitmapFromRgb(Packet packet) {
int width = PacketGetter.getImageWidth(packet);
int height = PacketGetter.getImageHeight(packet);
Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
copyRgbToBitmap(packet, bitmap, width, height);
return bitmap;
}
/**
* Copies data from an RGB mediapipe image frame packet to {@code ARGB_8888} bitmap.
*
* @param packet mediapipe packet
* @param inBitmap mutable {@link Bitmap} of same dimension and config as the expected output, the
* image would be copied to this {@link Bitmap}
*/
public static void copyRgbToBitmap(Packet packet, Bitmap inBitmap) {
checkArgument(inBitmap.isMutable(), "Input bitmap should be mutable.");
checkArgument(
inBitmap.getConfig() == Config.ARGB_8888, "Input bitmap should be of type ARGB_8888.");
int width = PacketGetter.getImageWidth(packet);
int height = PacketGetter.getImageHeight(packet);
checkArgument(inBitmap.getByteCount() == width * height * 4, "Input bitmap size mismatch.");
copyRgbToBitmap(packet, inBitmap, width, height);
}
private static void copyRgbToBitmap(Packet packet, Bitmap mutableBitmap, int width, int height) {
// TODO: use NDK Bitmap access instead of copyPixelsToBuffer.
ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4);
PacketGetter.getRgbaFromRgb(packet, buffer);
Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
bitmap.copyPixelsFromBuffer(buffer);
return bitmap;
mutableBitmap.copyPixelsFromBuffer(buffer);
}
/**
* Gets an {@code ARGB_8888} bitmap from an RGBA mediapipe image frame packet. Returns null in
* case of failure.
*
* @param packet mediapipe packet
* @return {@link Bitmap} with pixels copied from the packet
*/
public static Bitmap getBitmapFromRgba(Packet packet) {
// TODO: unify into a single getBitmap call.
// TODO: use NDK Bitmap access instead of copyPixelsToBuffer.
int width = PacketGetter.getImageWidth(packet);
int height = PacketGetter.getImageHeight(packet);
Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
copyRgbaToBitmap(packet, bitmap, width, height);
return bitmap;
}
/**
* Copies data from an RGBA mediapipe image frame packet to {@code ARGB_8888} bitmap.
*
* @param packet mediapipe packet
* @param inBitmap mutable {@link Bitmap} of same dimension and config as the expected output, the
* image would be copied to this {@link Bitmap}.
*/
public static void copyRgbaToBitmap(Packet packet, Bitmap inBitmap) {
checkArgument(inBitmap.isMutable(), "Input bitmap should be mutable.");
checkArgument(
inBitmap.getConfig() == Config.ARGB_8888, "Input bitmap should be of type ARGB_8888.");
int width = PacketGetter.getImageWidth(packet);
int height = PacketGetter.getImageHeight(packet);
checkArgument(inBitmap.getByteCount() == width * height * 4, "Input bitmap size mismatch.");
copyRgbaToBitmap(packet, inBitmap, width, height);
}
private static void copyRgbaToBitmap(Packet packet, Bitmap mutableBitmap, int width, int height) {
// TODO: unify into a single getBitmap call.
// TODO: use NDK Bitmap access instead of copyPixelsToBuffer.
ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4);
buffer.order(ByteOrder.nativeOrder());
// Note: even though the Android Bitmap config is named ARGB_8888, the data
// is stored as RGBA internally.
boolean status = PacketGetter.getImageData(packet, buffer);
if (!status) {
logger.atSevere().log(
"Got error from getImageData, returning null Bitmap. Image width %d, height %d",
width, height);
return null;
}
Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
bitmap.copyPixelsFromBuffer(buffer);
return bitmap;
checkState(
status,
String.format(
"Got error from getImageData, returning null Bitmap. Image width %d, height %d",
width, height));
mutableBitmap.copyPixelsFromBuffer(buffer);
}
private AndroidPacketGetter() {}

View File

@ -81,18 +81,10 @@ mediapipe::Status TFLiteGPURunner::Build() {
// 2. Describe output/input objects for created builder.
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
if (input_ssbo_ids_.find(flow_index) == input_ssbo_ids_.end()) {
return absl::AlreadyExistsError(absl::Substitute(
"Couldn't find a OpenGL ssbo for input $0.", flow_index));
}
MP_RETURN_IF_ERROR(builder->SetInputObjectDef(
flow_index, GetSSBOObjectDef(input_shapes_[flow_index].c)));
}
for (int flow_index = 0; flow_index < output_shapes_.size(); ++flow_index) {
if (output_ssbo_ids_.find(flow_index) == output_ssbo_ids_.end()) {
return absl::AlreadyExistsError(absl::Substitute(
"Couldn't find a OpenGL ssbo for output $0.", flow_index));
}
MP_RETURN_IF_ERROR(builder->SetOutputObjectDef(
flow_index, GetSSBOObjectDef(output_shapes_[flow_index].c)));
}

View File

@ -73,11 +73,6 @@ class TFLiteGPURunner {
std::unique_ptr<GraphFloat32> graph_;
std::unique_ptr<InferenceRunner> runner_;
// Store registered OpenGL ssbo ids for the corresponding input/output tensor.
// key: io tensor position, value: OpenGL ssbo id.
std::unordered_map<int, GLuint> input_ssbo_ids_;
std::unordered_map<int, GLuint> output_ssbo_ids_;
// We keep information about input/output shapes, because they are needed
// after graph_ becomes "converted" into runner_.
std::vector<BHWC> input_shapes_;