Project import generated by Copybara.

GitOrigin-RevId: e3a43e4e5e519cd14df7095749059e2613bdcf76

commit e9fbe868e5
parent 67bd8a2bf0

BUILD (2 lines changed)
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 The MediaPipe Authors.
+# Copyright 2019 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -101,7 +101,7 @@ run code search using
 
 ## Videos
 
-*   [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
+*   [YouTube Channel](https://www.youtube.com/c/MediaPipe)
 
 ## Events
 
@@ -123,7 +123,7 @@ run code search using
 
 *   [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
     MediaPipe related frameworks, libraries and software
-*   [Slack community](https://mediapipe.slack.com) for MediaPipe users
+*   [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
 *   [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
     community discussion around MediaPipe
 
WORKSPACE (13 lines changed)

@@ -37,10 +37,19 @@ http_archive(
 )
 
 # GoogleTest/GoogleMock framework. Used by most unit-tests.
+# Last updated 2020-06-30.
 http_archive(
     name = "com_google_googletest",
-    urls = ["https://github.com/google/googletest/archive/master.zip"],
-    strip_prefix = "googletest-master",
+    urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
+    patches = [
+        # fix for https://github.com/google/googletest/issues/2817
+        "@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
+    ],
+    patch_args = [
+        "-p1",
+    ],
+    strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
+    sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
 )
 
 # Google Benchmark library.
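For context, the pinned googletest archive above can be sanity-checked against the recorded `sha256` outside of Bazel; the commands below are purely illustrative, since Bazel performs the same verification during the fetch:

```bash
# Fetch the pinned googletest archive and compare it against the checksum
# recorded in the WORKSPACE hunk above.
wget -q https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip
echo "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895  aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip" | sha256sum --check
```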
build_ios_examples.sh (new file, 74 lines)

@@ -0,0 +1,74 @@
+#!/bin/bash
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+#
+# Script to build all MediaPipe iOS example apps.
+#
+# To build all apps and store them in out_dir:
+# $ ./build_ios_examples.sh -d out_dir
+# Omitting -d and the associated directory saves all generated IPAs in the
+# current directory.
+# $ ./build_ios_examples.sh -d out_dir --nostrip
+# Same as above except that the symbols are not stripped.
+
+set -e
+
+out_dir="."
+strip=true
+app_dir="mediapipe/examples/ios"
+bin_dir="bazel-bin"
+declare -a default_bazel_flags=(build -c opt --config=ios_arm64)
+
+while [[ -n $1 ]]; do
+  case $1 in
+    -d)
+      shift
+      out_dir=$1
+      ;;
+    --nostrip)
+      strip=false
+      ;;
+    *)
+      echo "Unsupported input argument $1."
+      exit 1
+      ;;
+  esac
+  shift
+done
+
+echo "app_dir: $app_dir"
+echo "out_dir: $out_dir"
+echo "strip: $strip"
+
+declare -a bazel_flags
+
+apps="${app_dir}/*"
+for app in ${apps}; do
+  if [[ -d "${app}" ]]; then
+    target_name=${app##*/}
+    target="${app}:${target_name}"
+
+    echo "=== Target: ${target}"
+
+    bazel_flags=("${default_bazel_flags[@]}")
+    bazel_flags+=(${target})
+    if [[ $strip == true ]]; then
+      bazel_flags+=(--linkopt=-s)
+    fi
+
+    bazel "${bazel_flags[@]}"
+    cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}"
+  fi
+done
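For reference, the new script is driven entirely by the `-d` and `--nostrip` flags documented in its header comment; a typical invocation might look like the following sketch (the `ipas` output directory is just an example name):

```bash
# Build all iOS example apps and collect the generated IPAs in ./ipas.
mkdir -p ipas
./build_ios_examples.sh -d ipas

# Same, but keep debug symbols in the binaries.
./build_ios_examples.sh -d ipas --nostrip
```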
@@ -149,15 +149,15 @@ When possible, these calculators use platform-specific functionality to share da
 
 The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU.
 
-| ![How GPU calculators interact](../images/gpu_example_graph.png) |
-| :--------------------------------------------------------------------------: |
-| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. |
-: The input stream is accessed by two calculators in parallel. :
-: `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, :
-: which is then sent through a grayscale converter and a canny filter (both :
-: based on OpenCV and running on the CPU), whose output is then converted into :
-: a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, :
-: takes as input both the original `GpuBuffer` and the one coming out of the :
-: edge detector, and overlays them using a shader. The output is then sent :
-: back to the application using a callback calculator, and the application :
-: renders the image to the screen using OpenGL.* :
+![How GPU calculators interact](../images/gpu_example_graph.png)
+
+Video frames from the camera are fed into the graph as `GpuBuffer` packets. The
+input stream is accessed by two calculators in parallel.
+`GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,
+which is then sent through a grayscale converter and a canny filter (both based
+on OpenCV and running on the CPU), whose output is then converted into a
+`GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as
+input both the original `GpuBuffer` and the one coming out of the edge detector,
+and overlays them using a shader. The output is then sent back to the
+application using a callback calculator, and the application renders the image
+to the screen using OpenGL.
@@ -184,12 +184,8 @@ app:
 
 ### Prerequisite
 
-1. Install [Xcode](https://developer.apple.com/xcode/) and the Command Line
-   Tools.
-
-   Follow Apple's instructions to obtain the required development certificates
-   and provisioning profiles for your iOS device. Install the Command Line
-   Tools by
+1. Install [Xcode](https://developer.apple.com/xcode/), and additionally
+   install the Command Line Tools by:
 
    ```bash
   xcode-select --install
@@ -209,26 +205,31 @@ app:
    pip3 install --user six
    ```
 
-4. Clone the MediaPipe repository.
+4. Follow
+   [Apple's instructions](https://developer.apple.com/support/certificates/) to
+   obtain the required development certificates and provisioning profiles for
+   your iOS device.
+
+   Tip: You can use the following command to see the provisioning profiles you have
+   previously downloaded using Xcode: `open
+   ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
+   and download a profile on
+   [Apple's developer site](https://developer.apple.com/account/resources/).
+
+5. Clone the MediaPipe repository.
 
    ```bash
   git clone https://github.com/google/mediapipe.git
   ```
 
-5. Symlink or copy your provisioning profile to
-   `mediapipe/mediapipe/provisioning_profile.mobileprovision`.
+6. In the cloned MediaPipe repository, symlink or copy your provisioning profile
+   to `mediapipe/provisioning_profile.mobileprovision`, e.g.,
 
    ```bash
   cd mediapipe
   ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
   ```
 
-   Tip: You can use this command to see the provisioning profiles you have
-   previously downloaded using Xcode: `open
-   ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
-   and download a profile on
-   [Apple's developer site](https://developer.apple.com/account/resources/).
-
 ### Option 1: Build with Bazel in Command Line
 
 1. Modify the `bundle_id` field of the app's `ios_application` build target to
@@ -246,6 +247,10 @@ app:
 
    You may see a permission request from `codesign` in order to sign the app.
 
+   Tip: You can run this
+   [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
+   to build all MediaPipe iOS example apps.
+
 3. In Xcode, open the `Devices and Simulators` window (command-shift-2).
 
 4. Make sure your device is connected. You will see a list of installed apps.
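The tip added above points at build_ios_examples.sh for building every example app at once; building a single app by hand uses the same flags the script passes to Bazel. A rough sketch, where the hand-tracking target name is only an assumed placeholder:

```bash
# Build one iOS example app with the flags used by build_ios_examples.sh
# (build -c opt --config=ios_arm64). Replace the target with the app you want.
bazel build -c opt --config=ios_arm64 \
  mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp
```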
@@ -44,6 +44,18 @@ apps, see these [instructions](./building_examples.md#ios).
    [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
    to install Bazel 2.0 or higher.
 
+   For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to
+   be built from source.
+
+   ```bash
+   # For Bazel 3.0.0
+   wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip
+   sudo apt-get install build-essential openjdk-8-jdk python zip unzip
+   unzip bazel-3.0.0-dist.zip
+   env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
+   sudo cp output/bazel /usr/local/bin/
+   ```
+
 3. Install OpenCV and FFmpeg.
 
    Option 1. Use package manager tool to install the pre-compiled OpenCV
@@ -58,6 +70,14 @@ apps, see these [instructions](./building_examples.md#ios).
    libopencv-imgproc-dev libopencv-video-dev
    ```
 
+   [`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia
+   Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be
+   modified.
+
+   ```bash
+   sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
+   ```
+
    Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source
    and modify MediaPipe's OpenCV config.
 
@@ -493,14 +513,14 @@ cameras. Alternatively, you use a video file as input.
 
    ```bash
    username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
-   https://storage.googleapis.com/bazel/2.0.0/release/bazel-2.0.0-installer-linux-x86_64.sh && \
-   sudo mkdir -p /usr/local/bazel/2.0.0 && \
-   chmod 755 bazel-2.0.0-installer-linux-x86_64.sh && \
-   sudo ./bazel-2.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/2.0.0 && \
-   source /usr/local/bazel/2.0.0/lib/bazel/bin/bazel-complete.bash
+   https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \
+   sudo mkdir -p /usr/local/bazel/3.0.0 && \
+   chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \
+   sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \
+   source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash
 
-   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \
-   alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel'
+   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \
+   alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel'
    ```
 
 6. Checkout MediaPipe repository.
 
@@ -101,7 +101,7 @@ run code search using
 
 ## Videos
 
-*   [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
+*   [YouTube Channel](https://www.youtube.com/c/MediaPipe)
 
 ## Events
 
@@ -123,7 +123,7 @@ run code search using
 
 *   [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
     MediaPipe related frameworks, libraries and software
-*   [Slack community](https://mediapipe.slack.com) for MediaPipe users
+*   [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
 *   [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
     community discussion around MediaPipe
 
@@ -1,6 +1,6 @@
 ---
 layout: default
-title: Hand
+title: Hands
 parent: Solutions
 nav_order: 3
 ---

@@ -219,9 +219,13 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web).
 
 ## Resources
 
-*   Google AI Blog: [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
-*   TensorFlow Blog: [Face and hand tracking in the browser with MediaPipe and
-    TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
+*   Google AI Blog:
+    [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
+*   TensorFlow Blog:
+    [Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
+*   Paper:
+    [MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214)
+    ([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk))
 *   Palm detection model:
     [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite),
     [TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
 
@@ -188,5 +188,8 @@ to visualize its associated subgraphs, please see
     [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
 *   Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
     Shape Supervision](https://arxiv.org/abs/2003.03522)
+*   Paper:
+    [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
+    ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
 *   [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite)
 *   [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite)
@@ -21,7 +21,16 @@ available on Linux, Android, or iOS.
 
 ## Enabling tracing and profiling
 
-To enable tracing/profiling of a mediapipe graph, the `CalculatorGraphConfig` (in
+To enable tracing and profiling of a mediapipe graph:
+
+1.  The profiling library must be linked to the framework.
+2.  Tracing and profiling must be enabled in the graph configuration.
+
+The profiling library is linked to the framework by default. If needed,
+the profiling library can be omitted from the framework using the bazel
+command line option: `--define MEDIAPIPE_PROFILING=0`.
+
+To enable tracing and profiling, the `CalculatorGraphConfig` (in
 [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
 representing the graph must have a `profiler_config` message at its root. Here
 is a simple setup that turns on a few extra options:
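Since the new text mentions that the profiling library can be compiled out with `--define MEDIAPIPE_PROFILING=0`, a build using that flag would look roughly like this (the target is an illustrative placeholder, not taken from the diff):

```bash
# Build a MediaPipe binary with the profiling library omitted from the framework.
bazel build -c opt --define MEDIAPIPE_PROFILING=0 \
  mediapipe/examples/desktop/hello_world:hello_world
```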
@ -386,14 +386,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
||||||
|
|
||||||
const int input_width = input_mat.cols;
|
const int input_width = input_mat.cols;
|
||||||
const int input_height = input_mat.rows;
|
const int input_height = input_mat.rows;
|
||||||
if (!output_height_ || !output_width_) {
|
int output_width;
|
||||||
output_height_ = input_height;
|
int output_height;
|
||||||
output_width_ = input_width;
|
ComputeOutputDimensions(input_width, input_height, &output_width,
|
||||||
}
|
&output_height);
|
||||||
|
|
||||||
|
if (output_width_ > 0 && output_height_ > 0) {
|
||||||
cv::Mat scaled_mat;
|
cv::Mat scaled_mat;
|
||||||
int output_width = output_width_;
|
|
||||||
int output_height = output_height_;
|
|
||||||
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
|
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
|
||||||
int scale_flag =
|
int scale_flag =
|
||||||
input_mat.cols > output_width_ && input_mat.rows > output_height_
|
input_mat.cols > output_width_ && input_mat.rows > output_height_
|
||||||
|
@ -416,7 +415,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
||||||
const int bottom = output_height_ - target_height - top;
|
const int bottom = output_height_ - target_height - top;
|
||||||
const int left = (output_width_ - target_width) / 2;
|
const int left = (output_width_ - target_width) / 2;
|
||||||
const int right = output_width_ - target_width - left;
|
const int right = output_width_ - target_width - left;
|
||||||
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left, right,
|
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left,
|
||||||
|
right,
|
||||||
options_.constant_padding() ? cv::BORDER_CONSTANT
|
options_.constant_padding() ? cv::BORDER_CONSTANT
|
||||||
: cv::BORDER_REPLICATE);
|
: cv::BORDER_REPLICATE);
|
||||||
} else {
|
} else {
|
||||||
|
@ -426,6 +426,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
||||||
output_height = target_height;
|
output_height = target_height;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
input_mat = scaled_mat;
|
||||||
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
|
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
|
||||||
auto padding = absl::make_unique<std::array<float, 4>>();
|
auto padding = absl::make_unique<std::array<float, 4>>();
|
||||||
|
@ -437,10 +439,33 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
||||||
}
|
}
|
||||||
|
|
||||||
cv::Mat rotated_mat;
|
cv::Mat rotated_mat;
|
||||||
|
cv::Size rotated_size(output_width, output_height);
|
||||||
|
if (input_mat.size() == rotated_size) {
|
||||||
const int angle = RotationModeToDegrees(rotation_);
|
const int angle = RotationModeToDegrees(rotation_);
|
||||||
cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0);
|
cv::Point2f src_center(input_mat.cols / 2.0, input_mat.rows / 2.0);
|
||||||
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
|
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
|
||||||
cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size());
|
cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
|
||||||
|
} else {
|
||||||
|
switch (rotation_) {
|
||||||
|
case mediapipe::RotationMode_Mode_UNKNOWN:
|
||||||
|
case mediapipe::RotationMode_Mode_ROTATION_0:
|
||||||
|
LOG(ERROR) << "Not rotating image.";
|
||||||
|
rotated_mat = input_mat;
|
||||||
|
break;
|
||||||
|
case mediapipe::RotationMode_Mode_ROTATION_90:
|
||||||
|
LOG(ERROR) << "Rotating image by 90 degrees ccw.";
|
||||||
|
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
|
||||||
|
break;
|
||||||
|
case mediapipe::RotationMode_Mode_ROTATION_180:
|
||||||
|
LOG(ERROR) << "Rotating image by 180 degrees.";
|
||||||
|
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
|
||||||
|
break;
|
||||||
|
case mediapipe::RotationMode_Mode_ROTATION_270:
|
||||||
|
LOG(ERROR) << "Rotating image by 90 degrees cw.";
|
||||||
|
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cv::Mat flipped_mat;
|
cv::Mat flipped_mat;
|
||||||
if (flip_horizontally_ || flip_vertically_) {
|
if (flip_horizontally_ || flip_vertically_) {
|
||||||
|
|
|
@ -139,7 +139,6 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
||||||
static_cast<::mediapipe::StatusCode>(status.code()),
|
static_cast<::mediapipe::StatusCode>(status.code()),
|
||||||
status.ToString());
|
status.ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto session = absl::make_unique<TensorFlowSession>();
|
auto session = absl::make_unique<TensorFlowSession>();
|
||||||
session->session = std::move(saved_model->session);
|
session->session = std::move(saved_model->session);
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||||
|
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
@ -202,6 +203,13 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
selects.config_setting_group(
|
||||||
|
name = "gpu_inference_disabled",
|
||||||
|
match_any = [
|
||||||
|
"//mediapipe/gpu:disable_gpu",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_inference_calculator",
|
name = "tflite_inference_calculator",
|
||||||
srcs = ["tflite_inference_calculator.cc"],
|
srcs = ["tflite_inference_calculator.cc"],
|
||||||
|
@ -226,13 +234,14 @@ cc_library(
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
|
"//mediapipe/util/tflite:config",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
] + select({
|
] + selects.with_or({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
":gpu_inference_disabled": [],
|
||||||
"//mediapipe:ios": [
|
"//mediapipe:ios": [
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/gpu:MPPMetalUtil",
|
"//mediapipe/gpu:MPPMetalUtil",
|
||||||
|
@ -285,6 +294,7 @@ cc_library(
|
||||||
}),
|
}),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/util/tflite:config",
|
||||||
":util",
|
":util",
|
||||||
":tflite_converter_calculator_cc_proto",
|
":tflite_converter_calculator_cc_proto",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
|
@ -295,23 +305,26 @@ cc_library(
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
] + select({
|
] + selects.with_or({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
":gpu_inference_disabled": [],
|
||||||
"//mediapipe:ios": [
|
"//mediapipe:ios": [
|
||||||
"//mediapipe/gpu:MPPMetalUtil",
|
"//mediapipe/gpu:MPPMetalUtil",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/objc:mediapipe_framework_ios",
|
"//mediapipe/objc:mediapipe_framework_ios",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||||
],
|
],
|
||||||
|
}) + select({
|
||||||
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
|
"//conditions:default": [
|
||||||
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -348,8 +361,8 @@ cc_library(
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
] + select({
|
] + selects.with_or({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
":gpu_inference_disabled": [],
|
||||||
"//mediapipe:ios": [],
|
"//mediapipe:ios": [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
|
@ -404,6 +417,7 @@ cc_library(
|
||||||
}),
|
}),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/util/tflite:config",
|
||||||
":util",
|
":util",
|
||||||
":tflite_tensors_to_detections_calculator_cc_proto",
|
":tflite_tensors_to_detections_calculator_cc_proto",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
@ -415,8 +429,8 @@ cc_library(
|
||||||
"//mediapipe/framework/formats/object_detection:anchor_cc_proto",
|
"//mediapipe/framework/formats/object_detection:anchor_cc_proto",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
] + select({
|
] + selects.with_or({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
":gpu_inference_disabled": [],
|
||||||
"//mediapipe:ios": [
|
"//mediapipe:ios": [
|
||||||
"//mediapipe/gpu:MPPMetalUtil",
|
"//mediapipe/gpu:MPPMetalUtil",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
@ -492,6 +506,8 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# To run this with native GPU on Linux, use:
|
||||||
|
# bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "tflite_inference_calculator_test",
|
name = "tflite_inference_calculator_test",
|
||||||
srcs = ["tflite_inference_calculator_test.cc"],
|
srcs = ["tflite_inference_calculator_test.cc"],
|
||||||
|
|
|
@ -22,19 +22,23 @@
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/util/resource_util.h"
|
#include "mediapipe/util/resource_util.h"
|
||||||
|
#include "mediapipe/util/tflite/config.h"
|
||||||
#include "tensorflow/lite/error_reporter.h"
|
#include "tensorflow/lite/error_reporter.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#ifndef MEDIAPIPE_DISABLE_GPU
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_IOS)
|
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
#import <CoreVideo/CoreVideo.h>
|
#import <CoreVideo/CoreVideo.h>
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalKit/MetalKit.h>
|
#import <MetalKit/MetalKit.h>
|
||||||
|
@ -43,13 +47,7 @@
|
||||||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||||
#endif // iOS
|
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
typedef id<MTLBuffer> GpuTensor;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
||||||
|
@ -73,7 +71,7 @@ constexpr char kMatrixTag[] = "MATRIX";
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||||
using ::tflite::gpu::gl::GlProgram;
|
using ::tflite::gpu::gl::GlProgram;
|
||||||
using ::tflite::gpu::gl::GlShader;
|
using ::tflite::gpu::gl::GlShader;
|
||||||
|
@ -83,13 +81,13 @@ struct GPUData {
|
||||||
GlShader shader;
|
GlShader shader;
|
||||||
GlProgram program;
|
GlProgram program;
|
||||||
};
|
};
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
struct GPUData {
|
struct GPUData {
|
||||||
int elements = 1;
|
int elements = 1;
|
||||||
GpuTensor buffer;
|
GpuTensor buffer;
|
||||||
id<MTLComputePipelineState> pipeline_state;
|
id<MTLComputePipelineState> pipeline_state;
|
||||||
};
|
};
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -157,13 +155,13 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
||||||
|
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
|
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||||
std::unique_ptr<GPUData> gpu_data_out_;
|
std::unique_ptr<GPUData> gpu_data_out_;
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||||
std::unique_ptr<GPUData> gpu_data_out_;
|
std::unique_ptr<GPUData> gpu_data_out_;
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
bool initialized_ = false;
|
bool initialized_ = false;
|
||||||
bool use_gpu_ = false;
|
bool use_gpu_ = false;
|
||||||
|
@ -178,6 +176,18 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <class CC>
|
||||||
|
bool ShouldUseGpu(CC* cc) {
|
||||||
|
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
return cc->Inputs().HasTag(kGpuBufferTag) ||
|
||||||
|
cc->Outputs().HasTag(kTensorsGpuTag);
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
::mediapipe::Status TfLiteConverterCalculator::GetContract(
|
::mediapipe::Status TfLiteConverterCalculator::GetContract(
|
||||||
CalculatorContract* cc) {
|
CalculatorContract* cc) {
|
||||||
// Confirm only one of the input streams is present.
|
// Confirm only one of the input streams is present.
|
||||||
|
@ -189,37 +199,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
|
RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
|
||||||
cc->Outputs().HasTag(kTensorsGpuTag));
|
cc->Outputs().HasTag(kTensorsGpuTag));
|
||||||
|
|
||||||
bool use_gpu = false;
|
|
||||||
|
|
||||||
if (cc->Inputs().HasTag(kImageFrameTag)) {
|
if (cc->Inputs().HasTag(kImageFrameTag)) {
|
||||||
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
|
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
if (cc->Inputs().HasTag(kMatrixTag)) {
|
if (cc->Inputs().HasTag(kMatrixTag)) {
|
||||||
cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
|
cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
|
||||||
}
|
}
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
#ifndef MEDIAPIPE_DISABLE_GPU
|
||||||
if (cc->Inputs().HasTag(kGpuBufferTag)) {
|
if (cc->Inputs().HasTag(kGpuBufferTag)) {
|
||||||
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
|
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
|
||||||
use_gpu |= true;
|
|
||||||
}
|
}
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
if (cc->Outputs().HasTag(kTensorsTag)) {
|
if (cc->Outputs().HasTag(kTensorsTag)) {
|
||||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||||
}
|
}
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
|
||||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
||||||
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||||
use_gpu |= true;
|
|
||||||
}
|
}
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
|
||||||
|
|
||||||
if (use_gpu) {
|
if (ShouldUseGpu(cc)) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assign this calculator's default InputStreamHandler.
|
// Assign this calculator's default InputStreamHandler.
|
||||||
|
@ -233,14 +237,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||||
|
|
||||||
if (cc->Inputs().HasTag(kGpuBufferTag) ||
|
use_gpu_ = ShouldUseGpu(cc);
|
||||||
cc->Outputs().HasTag(kGpuBufferTag)) {
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
|
||||||
use_gpu_ = true;
|
|
||||||
#else
|
|
||||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_gpu_) {
|
if (use_gpu_) {
|
||||||
// Cannot mix CPU/GPU streams.
|
// Cannot mix CPU/GPU streams.
|
||||||
|
@ -248,12 +245,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
cc->Outputs().HasTag(kTensorsGpuTag));
|
cc->Outputs().HasTag(kTensorsGpuTag));
|
||||||
// Cannot use quantization.
|
// Cannot use quantization.
|
||||||
use_quantized_tensors_ = false;
|
use_quantized_tensors_ = false;
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||||
RET_CHECK(gpu_helper_);
|
RET_CHECK(gpu_helper_);
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
} else {
|
} else {
|
||||||
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
||||||
interpreter_->AddTensors(1);
|
interpreter_->AddTensors(1);
|
||||||
|
@ -282,12 +279,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
interpreter_.reset();
|
||||||
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
||||||
#endif
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
#if defined(MEDIAPIPE_IOS)
|
|
||||||
gpu_data_out_.reset();
|
gpu_data_out_.reset();
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,8 +315,14 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
|
RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
|
||||||
<< "Only 8-bit input images are supported for quantization.";
|
<< "Only 8-bit input images are supported for quantization.";
|
||||||
quant.type = kTfLiteAffineQuantization;
|
quant.type = kTfLiteAffineQuantization;
|
||||||
quant.params = nullptr;
|
auto quant_params = static_cast<TfLiteAffineQuantization*>(
|
||||||
// Optional: Set 'quant' quantization params here if needed.
|
malloc(sizeof(TfLiteAffineQuantization)));
|
||||||
|
quant_params->scale = TfLiteFloatArrayCreate(1);
|
||||||
|
quant_params->scale->data[0] = 1.0;
|
||||||
|
quant_params->zero_point = TfLiteIntArrayCreate(1);
|
||||||
|
quant_params->zero_point->data[0] = 0;
|
||||||
|
quant_params->quantized_dimension = 0;
|
||||||
|
quant.params = quant_params;
|
||||||
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
|
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
|
||||||
{channels_preserved}, quant);
|
{channels_preserved}, quant);
|
||||||
} else {
|
} else {
|
||||||
|
@ -414,7 +417,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
|
|
||||||
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
|
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
||||||
const auto& input =
|
const auto& input =
|
||||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||||
|
@ -451,7 +454,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(kTensorsGpuTag)
|
.Tag(kTensorsGpuTag)
|
||||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
// GpuBuffer to id<MTLBuffer> conversion.
|
// GpuBuffer to id<MTLBuffer> conversion.
|
||||||
const auto& input =
|
const auto& input =
|
||||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||||
|
@ -490,13 +493,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||||
#else
|
#else
|
||||||
RET_CHECK_FAIL() << "GPU processing is not enabled.";
|
RET_CHECK_FAIL() << "GPU processing is not enabled.";
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
// Get input image sizes.
|
// Get input image sizes.
|
||||||
const auto& input =
|
const auto& input =
|
||||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||||
|
@ -512,9 +515,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
RET_CHECK_FAIL() << "Unsupported GPU input format.";
|
RET_CHECK_FAIL() << "Unsupported GPU input format.";
|
||||||
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
|
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
|
||||||
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||||
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
|
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
|
||||||
// Device memory.
|
// Device memory.
|
||||||
|
@ -559,7 +562,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
RET_CHECK(include_alpha)
|
RET_CHECK(include_alpha)
|
||||||
<< "iOS GPU inference currently accepts only RGBA input.";
|
<< "iOS GPU inference currently accepts only RGBA input.";
|
||||||
|
@ -616,7 +619,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||||
RET_CHECK(gpu_data_out_->pipeline_state != nil)
|
RET_CHECK(gpu_data_out_->pipeline_state != nil)
|
||||||
<< "Couldn't create pipeline state "
|
<< "Couldn't create pipeline state "
|
||||||
<< [[error localizedDescription] UTF8String];
|
<< [[error localizedDescription] UTF8String];
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "mediapipe/calculators/tflite/util.h"
|
#include "mediapipe/calculators/tflite/util.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/util/tflite/config.h"
|
||||||
|
|
||||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||||
#include "mediapipe/util/cpu_util.h"
|
#include "mediapipe/util/cpu_util.h"
|
||||||
|
@ -33,7 +34,7 @@
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
|
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
|
||||||
|
@ -42,9 +43,9 @@
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||||
#endif // !MEDIAPIPE_DISABLE_GL_COMPUTE
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_IOS)
|
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
#import <CoreVideo/CoreVideo.h>
|
#import <CoreVideo/CoreVideo.h>
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalKit/MetalKit.h>
|
#import <MetalKit/MetalKit.h>
|
||||||
|
@ -56,7 +57,7 @@
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
||||||
#endif // iOS
|
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_EDGE_TPU)
|
#if !defined(MEDIAPIPE_EDGE_TPU)
|
||||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||||
|
@ -71,12 +72,6 @@ int NumGroups(const int size, const int group_size) { // NOLINT
|
||||||
return (size + group_size - 1) / group_size;
|
return (size + group_size - 1) / group_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
typedef id<MTLBuffer> GpuTensor;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Round up n to next multiple of m.
|
// Round up n to next multiple of m.
|
||||||
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
|
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
|
||||||
|
|
||||||
|
@ -112,13 +107,13 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
|
||||||
// * Aux
|
// * Aux
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
using ::tflite::gpu::gl::CopyBuffer;
|
using ::tflite::gpu::gl::CopyBuffer;
|
||||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||||
using ::tflite::gpu::gl::GlBuffer;
|
using ::tflite::gpu::gl::GlBuffer;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
namespace {
|
namespace {
|
||||||
struct GPUData {
|
struct GPUData {
|
||||||
int elements = 1;
|
int elements = 1;
|
||||||
|
@ -126,7 +121,7 @@ struct GPUData {
|
||||||
::tflite::gpu::BHWC shape;
|
::tflite::gpu::BHWC shape;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
|
||||||
// Returns number of threads to configure XNNPACK delegate with.
|
// Returns number of threads to configure XNNPACK delegate with.
|
||||||
// (Equal to user provided value if specified. Otherwise, it returns number of
|
// (Equal to user provided value if specified. Otherwise, it returns number of
|
||||||
|
@ -152,7 +147,7 @@ int GetXnnpackNumThreads(
|
||||||
// Creates an interpreter with given model and calls invoke().
|
// Creates an interpreter with given model and calls invoke().
|
||||||
// Optionally run inference on CPU/GPU.
|
// Optionally run inference on CPU/GPU.
|
||||||
//
|
//
|
||||||
// This calculator is designed to be used with the TfLiteConverterCalcualtor,
|
// This calculator is designed to be used with the TfLiteConverterCalculator,
|
||||||
// to get the appropriate inputs.
|
// to get the appropriate inputs.
|
||||||
//
|
//
|
||||||
// When the input tensors are on CPU, gpu inference is optional and can be
|
// When the input tensors are on CPU, gpu inference is optional and can be
|
||||||
|
@ -183,7 +178,6 @@ int GetXnnpackNumThreads(
|
||||||
// options: {
|
// options: {
|
||||||
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
// model_path: "modelname.tflite"
|
// model_path: "modelname.tflite"
|
||||||
// delegate { gpu {} }
|
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
@ -192,11 +186,12 @@ int GetXnnpackNumThreads(
|
||||||
//
|
//
|
||||||
// node {
|
// node {
|
||||||
// calculator: "TfLiteInferenceCalculator"
|
// calculator: "TfLiteInferenceCalculator"
|
||||||
// input_stream: "TENSORS:tensor_image"
|
// input_stream: "TENSORS_GPU:tensor_image"
|
||||||
// input_side_packet: "MODEL:model"
|
// input_side_packet: "MODEL:model"
|
||||||
// output_stream: "TENSORS:tensors"
|
// output_stream: "TENSORS_GPU:tensors"
|
||||||
// options: {
|
// options: {
|
||||||
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
|
// model_path: "modelname.tflite"
|
||||||
// delegate { gpu {} }
|
// delegate { gpu {} }
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
@@ -228,24 +223,45 @@ class TfLiteInferenceCalculator : public CalculatorBase {
   ::mediapipe::Status LoadModel(CalculatorContext* cc);
   ::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
   ::mediapipe::Status LoadDelegate(CalculatorContext* cc);
-  ::mediapipe::Status InitTFLiteGPURunner();
+  ::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
+  ::mediapipe::Status ProcessInputsCpu(
+      CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
+  ::mediapipe::Status ProcessOutputsCpu(
+      CalculatorContext* cc,
+      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
+  ::mediapipe::Status ProcessInputsGpu(
+      CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
+  ::mediapipe::Status ProcessOutputsGpu(
+      CalculatorContext* cc,
+      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
+      std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
+
+  ::mediapipe::Status RunInContextIfNeeded(
+      std::function<::mediapipe::Status(void)> f) {
+    if (gpu_inference_) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+      return gpu_helper_.RunInGlContext(std::move(f));
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+    }
+    return f();
+  }

   Packet model_packet_;
   std::unique_ptr<tflite::Interpreter> interpreter_;
   TfLiteDelegatePtr delegate_;

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   mediapipe::GlCalculatorHelper gpu_helper_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
   std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;
   std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
   id<MTLComputePipelineState> fp32_to_fp16_program_;
   TFLBufferConvert* converter_from_BPHWC4_ = nil;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

 #if defined(MEDIAPIPE_EDGE_TPU)
   std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ =
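With these helpers in place, the rewritten Process() (see its hunk further below) becomes a short pipeline: pick the CPU or GPU input path, run inference, then emit through the matching output path, all inside RunInContextIfNeeded so GPU work stays on the GL context. A condensed sketch of that flow, abbreviated from the hunk below rather than copied verbatim:

// Condensed sketch of the new Process() flow; error handling and the Metal
// path are elided here, see the full hunk below for the real implementation.
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
  return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
    auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
    auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
    // 1. Receive pre-processed tensor inputs.
    if (gpu_input_) {
      MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
    } else {
      MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
    }
    // 2. Run inference via the GPU runner or the TFLite interpreter.
    if (gpu_inference_ && use_advanced_gpu_api_) {
      RET_CHECK(tflite_gpu_runner_->Invoke().ok());
    } else {
      RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
    }
    // 3. Output processed tensors on the CPU or GPU stream.
    if (gpu_output_ || use_advanced_gpu_api_) {
      MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
                                           std::move(output_tensors_gpu)));
    } else {
      MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
    }
    return ::mediapipe::OkStatus();
  });
}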
@@ -263,6 +279,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

 // Calculator Core Section

+namespace {
+template <class CC>
+bool ShouldUseGpu(CC* cc) {
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
+  const auto& options =
+      cc->template Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
+  return options.use_gpu() ||
+         (options.has_delegate() && options.delegate().has_gpu()) ||
+         cc->Inputs().HasTag(kTensorsGpuTag) ||
+         cc->Outputs().HasTag(kTensorsGpuTag);
+#else
+  return false;
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
+}
+}  // namespace
+
 ::mediapipe::Status TfLiteInferenceCalculator::GetContract(
     CalculatorContract* cc) {
   RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^
@@ -276,32 +308,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
             cc->InputSidePackets().HasTag("MODEL"))
       << "Either model as side packet or model path in options is required.";

-  bool use_gpu =
-      options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
-
   if (cc->Inputs().HasTag(kTensorsTag))
     cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-  if (cc->Inputs().HasTag(kTensorsGpuTag)) {
-    RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
-        << "GPU input is compatible with GPU delegate only.";
-
-    cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
-    use_gpu |= true;
-  }
-#endif  // !MEDIAPIPE_DISABLE_GPU

   if (cc->Outputs().HasTag(kTensorsTag))
     cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-  if (cc->Outputs().HasTag(kTensorsGpuTag)) {
-    RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
-        << "GPU output is compatible with GPU delegate only.";
-
+  if (cc->Inputs().HasTag(kTensorsGpuTag))
+    cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
+  if (cc->Outputs().HasTag(kTensorsGpuTag))
     cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
-    use_gpu |= true;
-  }
-#endif  // !MEDIAPIPE_DISABLE_GPU

   if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
     cc->InputSidePackets()
@@ -312,10 +327,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
     cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
   }

-  if (use_gpu) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+  if (ShouldUseGpu(cc)) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
 #endif
   }
|
@ -331,149 +346,111 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
|
|
||||||
const auto& options =
|
const auto& options =
|
||||||
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
|
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
|
||||||
gpu_inference_ = options.use_gpu();
|
|
||||||
|
|
||||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
gpu_inference_ = ShouldUseGpu(cc);
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
gpu_input_ = cc->Inputs().HasTag(kTensorsGpuTag);
|
||||||
gpu_input_ = true;
|
gpu_output_ = cc->Outputs().HasTag(kTensorsGpuTag);
|
||||||
gpu_inference_ = true; // Inference must be on GPU also.
|
|
||||||
#else
|
|
||||||
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
|
|
||||||
<< "GPU processing not enabled.";
|
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
use_advanced_gpu_api_ = MEDIAPIPE_TFLITE_GL_INFERENCE &&
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
options.has_delegate() &&
|
||||||
gpu_output_ = true;
|
options.delegate().has_gpu() &&
|
||||||
RET_CHECK(cc->Inputs().HasTag(kTensorsGpuTag))
|
options.delegate().gpu().use_advanced_gpu_api();
|
||||||
<< "GPU output must also have GPU Input.";
|
if (use_advanced_gpu_api_ && !gpu_input_) {
|
||||||
#else
|
LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
|
||||||
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
|
"Falling back to the default TFLite API.";
|
||||||
<< "GPU processing not enabled.";
|
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
|
||||||
}
|
|
||||||
|
|
||||||
use_advanced_gpu_api_ = false;
|
|
||||||
if (use_advanced_gpu_api_ && !(gpu_input_ && gpu_output_)) {
|
|
||||||
LOG(WARNING)
|
|
||||||
<< "Cannot use advanced GPU APIs, both inputs and outputs must "
|
|
||||||
"be GPU buffers. Falling back to the default TFLite API.";
|
|
||||||
use_advanced_gpu_api_ = false;
|
use_advanced_gpu_api_ = false;
|
||||||
}
|
}
|
||||||
|
CHECK(!use_advanced_gpu_api_ || gpu_inference_);
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||||
|
|
||||||
if (gpu_inference_) {
|
if (gpu_inference_) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
|
||||||
RET_CHECK(gpu_helper_);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
||||||
return use_advanced_gpu_api_ ? InitTFLiteGPURunner()
|
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
|
||||||
: LoadDelegate(cc);
|
: LoadDelegate(cc);
|
||||||
}));
|
}));
|
||||||
if (use_advanced_gpu_api_) return ::mediapipe::OkStatus();
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
#else
|
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||||
|
RET_CHECK(gpu_helper_);
|
||||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID)
|
// TODO: why only on these platforms?
|
||||||
|
// It seems that the XNNPACK delegate fails to load on Linux.
|
||||||
|
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \
|
||||||
|
defined(MEDIAPIPE_IOS)
|
||||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||||
#endif // __EMSCRIPTEN__ || ANDROID
|
#endif // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS
|
||||||
}
|
}
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
|
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
|
||||||
|
return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
|
||||||
// 0. Declare outputs
|
// 0. Declare outputs
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || defined(MEDIAPIPE_IOS)
|
|
||||||
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
|
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
|
||||||
#endif
|
|
||||||
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
|
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||||
|
|
||||||
// 1. Receive pre-processed tensor inputs.
|
// 1. Receive pre-processed tensor inputs.
|
||||||
if (use_advanced_gpu_api_ && gpu_output_) {
|
if (gpu_input_) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
|
||||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
} else {
|
||||||
return ::mediapipe::OkStatus();
|
MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
|
||||||
}
|
}
|
||||||
const auto& input_tensors =
|
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
// 2. Run inference.
|
||||||
RET_CHECK(!input_tensors.empty());
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
if (gpu_inference_ && use_advanced_gpu_api_) {
|
||||||
[this, &input_tensors, &output_tensors_gpu]() -> ::mediapipe::Status {
|
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
} else {
|
||||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
|
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||||
input_tensors[i].id(), i));
|
|
||||||
}
|
}
|
||||||
// Allocate output tensor.
|
#else
|
||||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
GpuTensor& tensor = output_tensors_gpu->at(i);
|
|
||||||
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
|
// 3. Output processed tensors.
|
||||||
gpu_data_out_[i]->elements, &tensor));
|
if (gpu_output_ || use_advanced_gpu_api_) {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
|
||||||
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
|
std::move(output_tensors_gpu)));
|
||||||
}
|
} else {
|
||||||
return ::mediapipe::OkStatus();
|
MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
|
||||||
}));
|
|
||||||
#endif
|
|
||||||
} else if (gpu_input_) {
|
|
||||||
// Read GPU input into SSBO.
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
|
||||||
return ::mediapipe::OkStatus();
|
|
||||||
}
|
|
||||||
const auto& input_tensors =
|
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
|
||||||
RET_CHECK_GT(input_tensors.size(), 0);
|
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
|
||||||
[this, &input_tensors]() -> ::mediapipe::Status {
|
|
||||||
// Explicit copy input.
|
|
||||||
gpu_data_in_.resize(input_tensors.size());
|
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
|
||||||
RET_CHECK_CALL(
|
|
||||||
CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}));
|
});
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
|
||||||
return ::mediapipe::OkStatus();
|
|
||||||
}
|
}
|
||||||
const auto& input_tensors =
|
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
|
||||||
RET_CHECK_GT(input_tensors.size(), 0);
|
return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
|
||||||
// Explicit copy input with conversion float 32 bits to 16 bits.
|
if (delegate_) {
|
||||||
gpu_data_in_.resize(input_tensors.size());
|
interpreter_ = nullptr;
|
||||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
delegate_ = nullptr;
|
||||||
command_buffer.label = @"TfLiteInferenceCalculatorConvert";
|
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
id<MTLComputeCommandEncoder> compute_encoder =
|
if (gpu_inference_) {
|
||||||
[command_buffer computeCommandEncoder];
|
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
||||||
[compute_encoder setComputePipelineState:fp32_to_fp16_program_];
|
gpu_data_in_[i].reset();
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
|
||||||
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
|
|
||||||
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
|
|
||||||
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
|
|
||||||
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
|
|
||||||
const int threadgroups =
|
|
||||||
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
|
|
||||||
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
|
|
||||||
threadsPerThreadgroup:threads_per_group];
|
|
||||||
}
|
}
|
||||||
[compute_encoder endEncoding];
|
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||||
[command_buffer commit];
|
gpu_data_out_[i].reset();
|
||||||
#else
|
}
|
||||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
}
|
||||||
|
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
}
|
||||||
|
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||||
|
edgetpu_context_.reset();
|
||||||
#endif
|
#endif
|
||||||
} else {
|
return ::mediapipe::OkStatus();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculator Auxiliary Section
|
||||||
|
|
||||||
|
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
|
||||||
|
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
|
||||||
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -496,39 +473,128 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
input_tensor->bytes);
|
input_tensor->bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Run inference.
|
|
||||||
if (gpu_inference_) {
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
MP_RETURN_IF_ERROR(
|
|
||||||
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
|
|
||||||
if (use_advanced_gpu_api_) {
|
|
||||||
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
|
|
||||||
} else {
|
|
||||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
|
||||||
}
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}));
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Output processed tensors.
|
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu(
|
||||||
|
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu) {
|
||||||
|
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
if (use_advanced_gpu_api_) {
|
if (use_advanced_gpu_api_) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
const auto& input_tensors =
|
||||||
|
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||||
|
RET_CHECK(!input_tensors.empty());
|
||||||
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
tflite_gpu_runner_->BindSSBOToInputTensor(input_tensors[i].id(), i));
|
||||||
|
}
|
||||||
|
if (gpu_output_) {
|
||||||
|
// Allocate new output tensor.
|
||||||
|
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||||
|
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||||
|
GpuTensor& tensor = output_tensors_gpu->at(i);
|
||||||
|
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
|
||||||
|
gpu_data_out_[i]->elements, &tensor));
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Re-use internal output tensor.
|
||||||
|
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||||
|
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
|
||||||
|
gpu_data_out_[i]->buffer.id(), i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
} else if (gpu_input_) {
|
||||||
|
// Read GPU input into SSBO.
|
||||||
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
const auto& input_tensors =
|
||||||
|
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||||
|
RET_CHECK_GT(input_tensors.size(), 0);
|
||||||
|
// Explicit copy input.
|
||||||
|
gpu_data_in_.resize(input_tensors.size());
|
||||||
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
|
RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
|
||||||
|
}
|
||||||
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
const auto& input_tensors =
|
||||||
|
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||||
|
RET_CHECK_GT(input_tensors.size(), 0);
|
||||||
|
// Explicit copy input with conversion float 32 bits to 16 bits.
|
||||||
|
gpu_data_in_.resize(input_tensors.size());
|
||||||
|
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||||
|
command_buffer.label = @"TfLiteInferenceCalculatorConvert";
|
||||||
|
id<MTLComputeCommandEncoder> compute_encoder =
|
||||||
|
[command_buffer computeCommandEncoder];
|
||||||
|
[compute_encoder setComputePipelineState:fp32_to_fp16_program_];
|
||||||
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
|
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
|
||||||
|
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
|
||||||
|
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
|
||||||
|
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
|
||||||
|
const int threadgroups =
|
||||||
|
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
|
||||||
|
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
|
||||||
|
threadsPerThreadgroup:threads_per_group];
|
||||||
|
}
|
||||||
|
[compute_encoder endEncoding];
|
||||||
|
[command_buffer commit];
|
||||||
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
}
|
||||||
|
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
|
||||||
|
CalculatorContext* cc,
|
||||||
|
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
|
||||||
|
// Output result tensors (CPU).
|
||||||
|
const auto& tensor_indexes = interpreter_->outputs();
|
||||||
|
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||||
|
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||||
|
output_tensors_cpu->emplace_back(*tensor);
|
||||||
|
}
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kTensorsTag)
|
||||||
|
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
||||||
|
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
|
||||||
|
CalculatorContext* cc,
|
||||||
|
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
|
||||||
|
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
|
||||||
|
if (use_advanced_gpu_api_) {
|
||||||
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
if (gpu_output_) {
|
||||||
|
// Send out pre-allocated tensors.
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(kTensorsGpuTag)
|
.Tag(kTensorsGpuTag)
|
||||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||||
#endif
|
} else {
|
||||||
|
// Download to CPU for output.
|
||||||
|
const auto& tensor_indexes = interpreter_->inputs();
|
||||||
|
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||||
|
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||||
|
std::vector<float> gpu_data(tensor->bytes / sizeof(float));
|
||||||
|
RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
|
||||||
|
absl::MakeSpan(tensor->data.f, tensor->bytes)));
|
||||||
|
output_tensors_cpu->emplace_back(*tensor);
|
||||||
|
}
|
||||||
|
// Output result tensors (CPU).
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kTensorsTag)
|
||||||
|
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
||||||
|
}
|
||||||
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
} else if (gpu_output_) {
|
} else if (gpu_output_) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
// Output result tensors (GPU).
|
// Output result tensors (GPU).
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
|
||||||
[this, &output_tensors_gpu]() -> ::mediapipe::Status {
|
|
||||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||||
GpuTensor& tensor = output_tensors_gpu->at(i);
|
GpuTensor& tensor = output_tensors_gpu->at(i);
|
||||||
|
@ -537,12 +603,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
gpu_data_out_[i]->elements, &tensor));
|
gpu_data_out_[i]->elements, &tensor));
|
||||||
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
|
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
|
||||||
}
|
}
|
||||||
return ::mediapipe::OkStatus();
|
|
||||||
}));
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(kTensorsGpuTag)
|
.Tag(kTensorsGpuTag)
|
||||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
// Output result tensors (GPU).
|
// Output result tensors (GPU).
|
||||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||||
|
@ -566,68 +630,58 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(kTensorsGpuTag)
|
.Tag(kTensorsGpuTag)
|
||||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||||
#else
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
|
||||||
} else {
|
|
||||||
// Output result tensors (CPU).
|
|
||||||
const auto& tensor_indexes = interpreter_->outputs();
|
|
||||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
|
||||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
|
||||||
output_tensors_cpu->emplace_back(*tensor);
|
|
||||||
}
|
|
||||||
cc->Outputs()
|
|
||||||
.Tag(kTensorsTag)
|
|
||||||
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
|
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
|
||||||
if (delegate_) {
|
CalculatorContext* cc) {
|
||||||
if (gpu_inference_) {
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||||
interpreter_ = nullptr;
|
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||||
delegate_ = nullptr;
|
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||||
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
op_resolver = cc->InputSidePackets()
|
||||||
gpu_data_in_[i].reset();
|
.Tag("CUSTOM_OP_RESOLVER")
|
||||||
}
|
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
|
||||||
gpu_data_out_[i].reset();
|
|
||||||
}
|
|
||||||
return ::mediapipe::OkStatus();
|
|
||||||
}));
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
|
||||||
interpreter_ = nullptr;
|
|
||||||
delegate_ = nullptr;
|
|
||||||
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
|
||||||
gpu_data_in_[i].reset();
|
|
||||||
}
|
|
||||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
|
||||||
gpu_data_out_[i].reset();
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
interpreter_ = nullptr;
|
|
||||||
delegate_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
|
||||||
edgetpu_context_.reset();
|
|
||||||
#endif
|
|
||||||
return ::mediapipe::OkStatus();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculator Auxiliary Section
|
// Create runner
|
||||||
|
tflite::gpu::InferenceOptions options;
|
||||||
|
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
|
||||||
|
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
||||||
|
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
||||||
|
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
||||||
|
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
||||||
|
RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
|
||||||
|
|
||||||
|
// Allocate interpreter memory for cpu output.
|
||||||
|
if (!gpu_output_) {
|
||||||
|
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
||||||
|
const int num_outputs = tflite_gpu_runner_->GetOutputShapes().size();
|
||||||
|
interpreter_->AddTensors(num_outputs);
|
||||||
|
std::vector<int> indices(num_outputs);
|
||||||
|
for (int i = 0; i < num_outputs; ++i) indices[i] = i;
|
||||||
|
// There is no ResizeOutputTensor(), so we use 'inputs' space instead.
|
||||||
|
interpreter_->SetInputs(indices);
|
||||||
|
TfLiteQuantization quant;
|
||||||
|
quant.type = kTfLiteNoQuantization;
|
||||||
|
quant.params = nullptr;
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
auto shape = tflite_gpu_runner_->GetOutputShapes()[i];
|
||||||
|
const int tensor_idx = interpreter_->inputs()[i];
|
||||||
|
interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
|
||||||
|
{shape.c}, quant);
|
||||||
|
CHECK(interpreter_->ResizeInputTensor(
|
||||||
|
tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk);
|
||||||
|
}
|
||||||
|
CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner() {
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
// Create and bind OpenGL buffers for outputs.
|
// Create and bind OpenGL buffers for outputs.
|
||||||
// These buffers are created onve and later their ids are jut passed to the
|
// The buffers are created once and their ids are passed to calculator outputs
|
||||||
// calculator outputs.
|
|
||||||
|
|
||||||
gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
|
gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
|
||||||
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
|
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
|
||||||
gpu_data_out_[i] = absl::make_unique<GPUData>();
|
gpu_data_out_[i] = absl::make_unique<GPUData>();
|
||||||
|
@ -638,15 +692,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
|
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
|
||||||
}
|
}
|
||||||
RET_CHECK_CALL(tflite_gpu_runner_->Build());
|
RET_CHECK_CALL(tflite_gpu_runner_->Build());
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
|
::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
|
if (use_advanced_gpu_api_) {
|
||||||
|
// Use InitTFLiteGPURunner for everything.
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||||
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||||
|
|
||||||
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||||
op_resolver = cc->InputSidePackets()
|
op_resolver = cc->InputSidePackets()
|
||||||
|
@ -654,19 +713,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
if (use_advanced_gpu_api_) {
|
|
||||||
tflite::gpu::InferenceOptions options;
|
|
||||||
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
|
|
||||||
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
|
||||||
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
|
||||||
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
|
||||||
tflite_gpu_runner_ =
|
|
||||||
std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
|
||||||
return tflite_gpu_runner_->InitializeWithModel(model, op_resolver);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||||
interpreter_ =
|
interpreter_ =
|
||||||
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
|
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
|
||||||
|
@ -771,7 +817,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||||
options.compile_options.precision_loss_allowed = 1;
|
options.compile_options.precision_loss_allowed = 1;
|
||||||
|
@ -832,9 +878,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
// Must call this last.
|
// Must call this last.
|
||||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
#endif // OpenGL
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_IOS)
|
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
const int kHalfSize = 2; // sizeof(half)
|
const int kHalfSize = 2; // sizeof(half)
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TFLGpuDelegateOptions options;
|
TFLGpuDelegateOptions options;
|
||||||
|
@ -958,7 +1004,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||||
"Error initializating output buffer converter");
|
"Error initializating output buffer converter");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // iOS
|
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -45,6 +45,8 @@ message TfLiteInferenceCalculatorOptions {
   message Gpu {
     // Experimental, Android/Linux only. Use TFLite GPU delegate API2 for
     // the NN inference.
+    // example:
+    //   delegate: { gpu { use_advanced_gpu_api: true } }
     optional bool use_advanced_gpu_api = 1 [default = false];
   }
   // Android only.
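For context, the reworked Open() in the calculator source above derives its runtime flag from this option roughly as sketched here; the helper name is invented for illustration, but the conditions mirror the new code: the advanced API requires GL inference support, an explicit gpu delegate, the option itself, and GPU input buffers.

#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"

// Sketch only: condensed from the new Open() logic. The advanced GPU API is
// gated on the build flag, on a gpu delegate in the options, and on the
// option itself; it additionally requires GPU input buffers at runtime.
bool ShouldUseAdvancedGpuApi(
    const mediapipe::TfLiteInferenceCalculatorOptions& options,
    bool gl_inference_enabled, bool has_gpu_input) {
  const bool requested = gl_inference_enabled && options.has_delegate() &&
                         options.delegate().has_gpu() &&
                         options.delegate().gpu().use_advanced_gpu_api();
  return requested && has_gpu_input;
}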
@@ -25,17 +25,18 @@
 #include "mediapipe/framework/formats/location.h"
 #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/util/tflite/config.h"
 #include "tensorflow/lite/interpreter.h"

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
 #import <MetalKit/MetalKit.h>
@ -44,7 +45,7 @@
|
||||||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||||
#endif // iOS
|
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kNumInputTensorsWithAnchors = 3;
|
constexpr int kNumInputTensorsWithAnchors = 3;
|
||||||
|
@ -56,22 +57,17 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||||
using ::tflite::gpu::gl::GlShader;
|
using ::tflite::gpu::gl::GlShader;
|
||||||
#endif
|
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
|
||||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
|
||||||
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
|
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
typedef id<MTLBuffer> GpuTensor;
|
|
||||||
typedef id<MTLComputePipelineState> GpuProgram;
|
typedef id<MTLComputePipelineState> GpuProgram;
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
struct GPUData {
|
struct GPUData {
|
||||||
GpuProgram decode_program;
|
GpuProgram decode_program;
|
||||||
GpuProgram score_program;
|
GpuProgram score_program;
|
||||||
|
@ -81,7 +77,7 @@ struct GPUData {
|
||||||
GpuTensor scored_boxes_buffer;
|
GpuTensor scored_boxes_buffer;
|
||||||
GpuTensor raw_scores_buffer;
|
GpuTensor raw_scores_buffer;
|
||||||
};
|
};
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||||
|
|
||||||
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
||||||
std::vector<Anchor>* anchors) {
|
std::vector<Anchor>* anchors) {
|
||||||
|
@ -181,13 +177,13 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
|
||||||
std::vector<Anchor> anchors_;
|
std::vector<Anchor> anchors_;
|
||||||
bool side_packet_anchors_{};
|
bool side_packet_anchors_{};
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||||
std::unique_ptr<GPUData> gpu_data_;
|
std::unique_ptr<GPUData> gpu_data_;
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||||
std::unique_ptr<GPUData> gpu_data_;
|
std::unique_ptr<GPUData> gpu_data_;
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
bool gpu_input_ = false;
|
bool gpu_input_ = false;
|
||||||
bool anchors_init_ = false;
|
bool anchors_init_ = false;
|
||||||
|
@ -205,12 +201,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
|
||||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||||
use_gpu |= true;
|
use_gpu |= true;
|
||||||
}
|
}
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("DETECTIONS")) {
|
if (cc->Outputs().HasTag("DETECTIONS")) {
|
||||||
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
|
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
|
||||||
|
@ -223,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (use_gpu) {
|
if (use_gpu) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
}
|
}
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
|
@ -239,12 +233,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
|
|
||||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||||
gpu_input_ = true;
|
gpu_input_ = true;
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||||
RET_CHECK(gpu_helper_);
|
RET_CHECK(gpu_helper_);
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
}
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||||
|
@ -401,7 +395,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
}
|
}
|
||||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
|
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
|
||||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
const auto& input_tensors =
|
const auto& input_tensors =
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||||
RET_CHECK_GE(input_tensors.size(), 2);
|
RET_CHECK_GE(input_tensors.size(), 2);
|
||||||
|
@ -464,7 +458,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}));
|
}));
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
const auto& input_tensors =
|
const auto& input_tensors =
|
||||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||||
|
@ -546,17 +540,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
||||||
|
|
||||||
#else
|
#else
|
||||||
LOG(ERROR) << "GPU input on non-Android not supported yet.";
|
LOG(ERROR) << "GPU input on non-Android not supported yet.";
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
|
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
|
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
gpu_data_.reset();
|
gpu_data_.reset();
|
||||||
#endif
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -705,7 +699,7 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
|
||||||
|
|
||||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
|
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
|
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
|
||||||
-> ::mediapipe::Status {
|
-> ::mediapipe::Status {
|
||||||
gpu_data_ = absl::make_unique<GPUData>();
|
gpu_data_ = absl::make_unique<GPUData>();
|
||||||
|
@ -918,7 +912,7 @@ void main() {
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
#elif defined(MEDIAPIPE_IOS)
|
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
|
|
||||||
gpu_data_ = absl::make_unique<GPUData>();
|
gpu_data_ = absl::make_unique<GPUData>();
|
||||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||||
|
@ -1148,7 +1142,7 @@ kernel void scoreKernel(
|
||||||
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
|
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -217,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
     for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
       const Landmark& landmark = output_landmarks.landmark(i);
       NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark();
-      norm_landmark->set_x(static_cast<float>(landmark.x()) /
-                           options_.input_image_width());
-      norm_landmark->set_y(static_cast<float>(landmark.y()) /
-                           options_.input_image_height());
-      norm_landmark->set_z(landmark.z() / options_.normalize_z());
+      norm_landmark->set_x(landmark.x() / options_.input_image_width());
+      norm_landmark->set_y(landmark.y() / options_.input_image_height());
+      // Scale Z coordinate as X + allow additional uniform normalization.
+      norm_landmark->set_z(landmark.z() / options_.input_image_width() /
+                           options_.normalize_z());
       norm_landmark->set_visibility(landmark.visibility());
     }
     cc->Outputs()
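In other words, X and Y are divided by the image width and height, while Z is scaled like X and then divided by the extra normalize_z factor. A small self-contained check with made-up numbers:

#include <cstdio>

// Worked example of the landmark normalization above (values are invented):
// X and Y are divided by image width/height; Z is divided by image width
// (scaled "as X") and then by the additional normalize_z factor.
int main() {
  const float image_width = 256.0f, image_height = 192.0f, normalize_z = 1.0f;
  const float x = 128.0f, y = 96.0f, z = 12.8f;
  const float norm_x = x / image_width;                // 0.5
  const float norm_y = y / image_height;               // 0.5
  const float norm_z = z / image_width / normalize_z;  // 0.05
  std::printf("%.3f %.3f %.3f\n", norm_x, norm_y, norm_z);
  return 0;
}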
@@ -29,7 +29,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
   required int32 num_landmarks = 1;

   // Size of the input image for the model. These options are used only when
-  // normalized landmarks is needed.
+  // normalized landmarks are needed. Z coordinate is scaled as X assuming
+  // a weak perspective projection camera model.
   optional int32 input_image_width = 2;
   optional int32 input_image_height = 3;

@@ -46,6 +47,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
   // beforehand.
   optional bool flip_horizontally = 6 [default = false];

-  // A value that z values should be divided by.
+  // A value that Z coordinates should be divided by. This option is used only
+  // when normalized landmarks are needed. It is applied in addition to Z
+  // coordinate being re-scaled as X.
   optional float normalize_z = 5 [default = 1.0];
 }
@@ -376,6 +376,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":timed_box_list_id_to_label_calculator_cc_proto",
+        "@com_google_absl//absl/container:node_hash_map",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
@@ -122,11 +122,13 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase {
       NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
       const float new_x = (landmark.x() - left) / (1.0f - left_and_right);
       const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom);
+      const float new_z =
+          landmark.z() / (1.0f - left_and_right);  // Scale Z coordinate as X.

       new_landmark->set_x(new_x);
       new_landmark->set_y(new_y);
       // Keep z-coord as is.
-      new_landmark->set_z(landmark.z());
+      new_landmark->set_z(new_z);
       // Keep visibility as is.
       new_landmark->set_visibility(landmark.visibility());
     }
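A quick numeric sanity check of the letterbox-removal mapping above, with invented padding values:

#include <cstdio>

// Numeric sanity check for the letterbox-removal mapping: X/Y are shifted and
// rescaled out of the padded range, Z is rescaled like X (padding invented).
int main() {
  const float left = 0.1f, right = 0.1f, top = 0.05f, bottom = 0.05f;
  const float left_and_right = left + right, top_and_bottom = top + bottom;
  const float x = 0.5f, y = 0.5f, z = 0.05f;
  const float new_x = (x - left) / (1.0f - left_and_right);  // 0.5
  const float new_y = (y - top) / (1.0f - top_and_bottom);   // 0.5
  const float new_z = z / (1.0f - left_and_right);           // 0.0625
  std::printf("%.4f %.4f %.4f\n", new_x, new_y, new_z);
  return 0;
}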
@@ -123,11 +123,12 @@ class LandmarkProjectionCalculator : public CalculatorBase {

       new_x = new_x * input_rect.width() + input_rect.x_center();
       new_y = new_y * input_rect.height() + input_rect.y_center();
+      const float new_z =
+          landmark.z() * input_rect.width();  // Scale Z coordinate as X.

       new_landmark->set_x(new_x);
       new_landmark->set_y(new_y);
-      // Keep z-coord as is.
-      new_landmark->set_z(landmark.z());
+      new_landmark->set_z(new_z);
       // Keep visibility as is.
       new_landmark->set_visibility(landmark.visibility());
     }
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "absl/container/node_hash_map.h"
 #include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/packet.h"
@@ -53,7 +54,7 @@ class TimedBoxListIdToLabelCalculator : public CalculatorBase {
   ::mediapipe::Status Process(CalculatorContext* cc) override;

  private:
-  std::unordered_map<int, std::string> label_map_;
+  absl::node_hash_map<int, std::string> label_map_;
 };
 REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator);
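The container swap above is a routine Abseil migration: absl::node_hash_map keeps pointers and references to elements stable across rehashes, so it is a drop-in replacement for std::unordered_map here. A minimal stand-alone illustration (not taken from the calculator):

#include <string>

#include "absl/container/node_hash_map.h"

int main() {
  absl::node_hash_map<int, std::string> label_map;
  label_map[0] = "person";
  label_map[1] = "car";
  // Pointers to mapped values stay valid even if the table rehashes later.
  const std::string* person = &label_map[0];
  label_map.reserve(1024);  // may trigger a rehash; *person is still valid
  return person->size() == 6 ? 0 : 1;
}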
@@ -1,4 +1 @@
-MediaPipe Examples
-==================
-
-This directory contains MediaPipe Android example applications. Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details.
+This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details.
@@ -1,7 +0,0 @@
-tricorder: {
-  options: {
-    builder: {
-      config: "android_arm64"
-    }
-  }
-}
@@ -83,7 +83,7 @@ android_binary(
     manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
     manifest_values = {
         "applicationId": "com.google.mediapipe.apps.objectdetection3d",
-        "appName": "Object Detection 3D",
+        "appName": "Objectron",
         "mainActivity": ".MainActivity",
         "cameraFacingFront": "False",
         "binaryGraphName": "object_detection_3d.binarypb",
@ -1,113 +1 @@
|
||||||
**Hello World**
|
This directory contains MediaPipe example applications for desktop. Please see [Solutions](https://solutions.mediapipe.dev)for details.
|
||||||
|
|
||||||
To build the "Hello World" example, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world
|
|
||||||
```
|
|
||||||
|
|
||||||
and then run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/hello_world/hello_world
|
|
||||||
```
|
|
||||||
|
|
||||||
**TFlite Object Detection**
|
|
||||||
|
|
||||||
To build the object detection demo using a TFLite model on desktop, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
|
|
||||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
|
|
||||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
|
||||||
```
|
|
||||||
|
|
||||||
**TensorFlow Object Detection**
|
|
||||||
|
|
||||||
To build the object detection demo using a TensorFlow model on desktop, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \
|
|
||||||
--define MEDIAPIPE_DISABLE_GPU=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
|
|
||||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
|
|
||||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
|
||||||
```
|
|
||||||
|
|
||||||
**TFlite Hand Detection**
|
|
||||||
|
|
||||||
To build the hand detection demo using a TFLite model on desktop, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
|
|
||||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt \
|
|
||||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
|
||||||
```
|
|
||||||
|
|
||||||
**TFlite Hand Tracking**
|
|
||||||
|
|
||||||
To build the hand tracking demo using a TFLite model on desktop, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
|
|
||||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt \
|
|
||||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
|
||||||
```
|
|
||||||
|
|
||||||
**TFlite Multi-Hand Tracking**
|
|
||||||
|
|
||||||
To build the multi-hand tracking demo using a TFLite model on desktop, use:
|
|
||||||
|
|
||||||
```
|
|
||||||
bazel build -c opt mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and run it using:
|
|
||||||
|
|
||||||
```
|
|
||||||
export GLOG_logtostderr=1
|
|
||||||
|
|
||||||
bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_tflite \
|
|
||||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt \
|
|
||||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
|
||||||
```
|
|
||||||
|
|
||||||
To change the number of hands to `x` in this application, change:
|
|
||||||
|
|
||||||
1. `min_size:x` in `CollectionHasMinSizeCalculatorOptions` in `mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt`.
|
|
||||||
2. `max_vec_size:x` in `ClipVectorSizeCalculatorOptions` in `mediapipe/examples/desktop/hand_tracking/subgraphs/multi_hand_detection_cpu.pbtxt`.
|
|
||||||
|
|
|
@ -62,8 +62,10 @@ cc_library(
|
||||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/formats:image_frame_opencv",
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
],
|
],
|
||||||
|
@ -126,17 +128,20 @@ cc_test(
|
||||||
":content_zooming_calculator",
|
":content_zooming_calculator",
|
||||||
":content_zooming_calculator_cc_proto",
|
":content_zooming_calculator_cc_proto",
|
||||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||||
|
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/formats:image_frame_opencv",
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/port:benchmark",
|
"//mediapipe/framework/port:benchmark",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:opencv_core",
|
"//mediapipe/framework/port:opencv_core",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,16 +19,20 @@
|
||||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/status_builder.h"
|
#include "mediapipe/framework/port/status_builder.h"
|
||||||
|
|
||||||
constexpr char kVideoFrame[] = "VIDEO";
|
constexpr char kVideoFrame[] = "VIDEO";
|
||||||
constexpr char kVideoSize[] = "VIDEO_SIZE";
|
constexpr char kVideoSize[] = "VIDEO_SIZE";
|
||||||
constexpr char kDetectionSet[] = "DETECTIONS";
|
constexpr char kSalientRegions[] = "SALIENT_REGIONS";
|
||||||
|
constexpr char kDetections[] = "DETECTIONS";
|
||||||
constexpr char kDetectedBorders[] = "BORDERS";
|
constexpr char kDetectedBorders[] = "BORDERS";
|
||||||
|
constexpr char kCropRect[] = "CROP_RECT";
|
||||||
// Field-of-view (degrees) of the camera's x-axis (width).
|
// Field-of-view (degrees) of the camera's x-axis (width).
|
||||||
// TODO: Parameterize FOV based on camera specs.
|
// TODO: Parameterize FOV based on camera specs.
|
||||||
constexpr float kWidthFieldOfView = 60;
|
constexpr float kWidthFieldOfView = 60;
|
||||||
|
@ -37,12 +41,12 @@ namespace mediapipe {
|
||||||
namespace autoflip {
|
namespace autoflip {
|
||||||
|
|
||||||
// Content zooming calculator zooms in on content when a detection has
|
// Content zooming calculator zooms in on content when a detection has
|
||||||
// "only_required" set true. It does this by computing the value of top/bottom
|
// "only_required" set true or any raw detection input. It does this by
|
||||||
// borders to remove from the output and sends these to the
|
// computing the value of top/bottom borders to remove from the output and sends
|
||||||
// SceneCroppingCalculator. When more than one detections are received the zoom
|
// these to the SceneCroppingCalculator using BORDERS output or a full rect crop
|
||||||
// box is calculated as the union of the detections. Typical applications
|
// using CROP_RECT output. When more than one detections are received the
|
||||||
// include mobile makeover and autofliplive face reframing. Currently only
|
// zoom box is calculated as the union of the detections. Typical applications
|
||||||
// supports y-dimension zooming.
|
// include mobile makeover and autofliplive face reframing.
|
||||||
class ContentZoomingCalculator : public CalculatorBase {
|
class ContentZoomingCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
ContentZoomingCalculator()
|
ContentZoomingCalculator()
|
||||||
|
@ -56,26 +60,32 @@ class ContentZoomingCalculator : public CalculatorBase {
|
||||||
::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;
|
::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Converts bounds to tilt offset and height.
|
// Converts bounds to tilt offset, pan offset and height.
|
||||||
::mediapipe::Status ConvertToTiltZoom(float xmin, float xmax, float ymin,
|
::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
||||||
float ymax, int* tilt_offset,
|
float ymax, int* tilt_offset,
|
||||||
int* height);
|
int* pan_offset, int* height);
|
||||||
ContentZoomingCalculatorOptions options_;
|
ContentZoomingCalculatorOptions options_;
|
||||||
// Detection frame width/height.
|
// Detection frame width/height.
|
||||||
int frame_height_;
|
int frame_height_;
|
||||||
int frame_width_;
|
int frame_width_;
|
||||||
// Path solver used to smooth top/bottom border crop values.
|
// Path solver used to smooth top/bottom border crop values.
|
||||||
std::unique_ptr<KinematicPathSolver> path_solver_height_;
|
std::unique_ptr<KinematicPathSolver> path_solver_height_;
|
||||||
|
std::unique_ptr<KinematicPathSolver> path_solver_width_;
|
||||||
std::unique_ptr<KinematicPathSolver> path_solver_offset_;
|
std::unique_ptr<KinematicPathSolver> path_solver_offset_;
|
||||||
// Are parameters initialized.
|
// Are parameters initialized.
|
||||||
bool initialized_;
|
bool initialized_;
|
||||||
// Stores the time of the last "only_required" input.
|
// Stores the time of the last "only_required" input.
|
||||||
int64 last_only_required_detection_;
|
int64 last_only_required_detection_;
|
||||||
// Border values of last message with detection.
|
// Rect values of last message with detection(s).
|
||||||
int last_measured_height_;
|
int last_measured_height_;
|
||||||
|
int last_measured_x_offset_;
|
||||||
int last_measured_y_offset_;
|
int last_measured_y_offset_;
|
||||||
// Min border values.
|
// Target aspect ratio.
|
||||||
float min_height_value_;
|
float target_aspect_;
|
||||||
|
// Max size of bounding box. If input/output aspect ratios are the same,
|
||||||
|
// will be 1.0. Else, will be less than 1.0 to prevent exceeding the size of
|
||||||
|
// the image in either dimension.
|
||||||
|
float max_frame_value_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(ContentZoomingCalculator);
|
REGISTER_CALCULATOR(ContentZoomingCalculator);
|
||||||
|
|
||||||
|
@ -92,8 +102,18 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
|
||||||
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Input VIDEO or VIDEO_SIZE must be provided.";
|
<< "Input VIDEO or VIDEO_SIZE must be provided.";
|
||||||
}
|
}
|
||||||
cc->Inputs().Tag(kDetectionSet).Set<DetectionSet>();
|
if (cc->Inputs().HasTag(kSalientRegions)) {
|
||||||
|
cc->Inputs().Tag(kSalientRegions).Set<DetectionSet>();
|
||||||
|
}
|
||||||
|
if (cc->Inputs().HasTag(kDetections)) {
|
||||||
|
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
|
||||||
|
}
|
||||||
|
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||||
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
||||||
|
}
|
||||||
|
if (cc->Outputs().HasTag(kCropRect)) {
|
||||||
|
cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
|
||||||
|
}
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
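For orientation, the contract above lets the calculator be driven either by autoflip's SALIENT_REGIONS/BORDERS pair or directly by raw detections with a CROP_RECT output. Below is a minimal wiring sketch in the same text-proto-literal style used by the tests in this change; the stream names are hypothetical.

```
// Hypothetical node wiring for the new raw-detection input and CROP_RECT
// output. Only the tags come from GetContract above; the stream names are
// placeholders.
constexpr char kExampleReframeNode[] = R"(
  calculator: "ContentZoomingCalculator"
  input_stream: "VIDEO_SIZE:input_video_size"
  input_stream: "DETECTIONS:face_detections"
  output_stream: "CROP_RECT:reframe_rect"
)";
```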
|
|
||||||
|
@ -108,29 +128,38 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
|
||||||
if (options_.has_min_motion_to_reframe()) {
|
if (options_.has_min_motion_to_reframe()) {
|
||||||
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Deprecated min_motion_to_reframe was set, please set "
|
<< "Deprecated min_motion_to_reframe was set, please set "
|
||||||
"in kinematic_options_zoom and kinematic_options_tilt directly.";
|
"in kinematic_options_zoom and kinematic_options_tilt "
|
||||||
|
"directly.";
|
||||||
}
|
}
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status ContentZoomingCalculator::ConvertToTiltZoom(
|
::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
|
||||||
float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
|
float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
|
||||||
int* height) {
|
int* pan_offset, int* height) {
|
||||||
// Find center of the y-axis offset (for tilt control).
|
// Find center of the y-axis offset (for tilt control).
|
||||||
float y_center = ymin + (ymax - ymin) / 2;
|
float y_center = ymin + (ymax - ymin) / 2;
|
||||||
|
// Find center of the x-axis offset (for pan control).
|
||||||
|
float x_center = xmin + (xmax - xmin) / 2;
|
||||||
// Find size and apply scale factor to y-axis.
|
// Find size and apply scale factor to y-axis.
|
||||||
float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
|
float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
|
||||||
// Apply min zoom for cases where the target size is wider than input frame
|
// Apply max frame for cases where the target size is different than input
|
||||||
// size.
|
// frame size.
|
||||||
fit_size = fmin(min_height_value_, fit_size);
|
fit_size = fmin(max_frame_value_, fit_size);
|
||||||
// Prevent box from extending beyond the image.
|
// Prevent box from extending beyond the image.
|
||||||
if (y_center - fit_size / 2 < 0) {
|
if (y_center - fit_size / 2 < 0) {
|
||||||
y_center = fit_size / 2;
|
y_center = fit_size / 2;
|
||||||
} else if (y_center + fit_size / 2 > 1) {
|
} else if (y_center + fit_size / 2 > 1) {
|
||||||
y_center = 1 - fit_size / 2;
|
y_center = 1 - fit_size / 2;
|
||||||
}
|
}
|
||||||
|
if (x_center - fit_size / 2 < 0) {
|
||||||
|
x_center = fit_size / 2;
|
||||||
|
} else if (x_center + fit_size / 2 > 1) {
|
||||||
|
x_center = 1 - fit_size / 2;
|
||||||
|
}
|
||||||
// Scale to pixel coordinates.
|
// Scale to pixel coordinates.
|
||||||
*tilt_offset = frame_height_ * y_center;
|
*tilt_offset = frame_height_ * y_center;
|
||||||
|
*pan_offset = frame_width_ * x_center;
|
||||||
*height = frame_height_ * fit_size;
|
*height = frame_height_ * fit_size;
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
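To make the conversion concrete, the sketch below replays the arithmetic with the numbers used by the tests later in this change (a .1 x .1 detection at (.4, .5) on a 1000x1000 frame). The scale factor of 0.9 is an assumption, not shown in this diff, chosen because it reproduces the 111-pixel crop size the tests expect.

```
// Standalone sketch of ConvertToPanTiltZoom's arithmetic. Assumptions: the
// detection and frame size match the tests below; scale_factor = 0.9 and
// max_frame_value = 1.0 are assumed values.
#include <cmath>
#include <cstdio>

int main() {
  const float xmin = 0.4f, xmax = 0.5f, ymin = 0.5f, ymax = 0.6f;
  const int frame_width = 1000, frame_height = 1000;
  const float scale_factor = 0.9f;     // assumed
  const float max_frame_value = 1.0f;  // square input cropped to square output

  float y_center = ymin + (ymax - ymin) / 2;  // 0.55 -> tilt
  float x_center = xmin + (xmax - xmin) / 2;  // 0.45 -> pan
  float fit_size =
      std::fmax((ymax - ymin) / scale_factor, xmax - xmin);  // ~0.111
  fit_size = std::fmin(max_frame_value, fit_size);

  // Keep the crop box inside the unit square (a no-op for these values).
  y_center = std::fmax(fit_size / 2, std::fmin(1 - fit_size / 2, y_center));
  x_center = std::fmax(fit_size / 2, std::fmin(1 - fit_size / 2, x_center));

  const int tilt_offset = frame_height * y_center;  // 550
  const int pan_offset = frame_width * x_center;    // 450
  const int height = frame_height * fit_size;       // 111
  std::printf("pan=%d tilt=%d height=%d\n", pan_offset, tilt_offset, height);
  return 0;
}
```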
|
@ -151,6 +180,20 @@ namespace {
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection,
|
||||||
|
float* xmin, float* xmax, float* ymin,
|
||||||
|
float* ymax) {
|
||||||
|
RET_CHECK(detection.location_data().format() ==
|
||||||
|
mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
|
||||||
|
<< "Face detection input is lacking required relative_bounding_box()";
|
||||||
|
const auto& location = detection.location_data().relative_bounding_box();
|
||||||
|
*xmin = fmin(*xmin, location.xmin());
|
||||||
|
*xmax = fmax(*xmax, location.xmin() + location.width());
|
||||||
|
*ymin = fmin(*ymin, location.ymin());
|
||||||
|
*ymax = fmax(*ymax, location.ymin() + location.height());
|
||||||
|
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
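The helper above folds each raw detection into a running union of bounds. A tiny standalone illustration of that behaviour, using two hypothetical boxes:

```
// Union behaviour of UpdateRanges, starting from the sentinel ranges
// (xmin = ymin = 1, xmax = ymax = 0) used by Process below. The two boxes
// {x, y, width, height} are made up for illustration.
#include <algorithm>
#include <cstdio>

int main() {
  float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
  const float boxes[2][4] = {{0.4f, 0.5f, 0.1f, 0.1f},
                             {0.6f, 0.7f, 0.1f, 0.1f}};
  for (const auto& b : boxes) {
    xmin = std::min(xmin, b[0]);
    xmax = std::max(xmax, b[0] + b[2]);
    ymin = std::min(ymin, b[1]);
    ymax = std::max(ymax, b[1] + b[3]);
  }
  // Prints x:[0.4, 0.7] y:[0.5, 0.8] -- the box enclosing both detections.
  std::printf("x:[%.1f, %.1f] y:[%.1f, %.1f]\n", xmin, xmax, ymin, ymax);
  return 0;
}
```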
void MakeStaticFeatures(const int top_border, const int bottom_border,
|
void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
const int frame_width, const int frame_height,
|
const int frame_width, const int frame_height,
|
||||||
StaticFeatures* static_feature) {
|
StaticFeatures* static_feature) {
|
||||||
|
@ -173,10 +216,8 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
::mediapipe::Status ContentZoomingCalculator::Process(
|
::mediapipe::Status ContentZoomingCalculator::Process(
|
||||||
mediapipe::CalculatorContext* cc) {
|
mediapipe::CalculatorContext* cc) {
|
||||||
if (cc->Inputs().HasTag(kVideoFrame)) {
|
if (cc->Inputs().HasTag(kVideoFrame)) {
|
||||||
cv::Mat frame = mediapipe::formats::MatView(
|
frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
|
||||||
&cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>());
|
frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
|
||||||
frame_width_ = frame.cols;
|
|
||||||
frame_height_ = frame.rows;
|
|
||||||
} else if (cc->Inputs().HasTag(kVideoSize)) {
|
} else if (cc->Inputs().HasTag(kVideoSize)) {
|
||||||
frame_width_ =
|
frame_width_ =
|
||||||
cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
|
cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
|
||||||
|
@ -191,10 +232,14 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
path_solver_height_ = std::make_unique<KinematicPathSolver>(
|
path_solver_height_ = std::make_unique<KinematicPathSolver>(
|
||||||
options_.kinematic_options_zoom(), 0, frame_height_,
|
options_.kinematic_options_zoom(), 0, frame_height_,
|
||||||
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||||
|
path_solver_width_ = std::make_unique<KinematicPathSolver>(
|
||||||
|
options_.kinematic_options_pan(), 0, frame_width_,
|
||||||
|
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||||
path_solver_offset_ = std::make_unique<KinematicPathSolver>(
|
path_solver_offset_ = std::make_unique<KinematicPathSolver>(
|
||||||
options_.kinematic_options_tilt(), 0, frame_height_,
|
options_.kinematic_options_tilt(), 0, frame_height_,
|
||||||
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||||
min_height_value_ = 1.0;
|
max_frame_value_ = 1.0;
|
||||||
|
target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
|
||||||
// If target size is set and wider than input aspect, make sure to always
|
// If target size is set and wider than input aspect, make sure to always
|
||||||
// crop the min required amount.
|
// crop the min required amount.
|
||||||
if (options_.has_target_size()) {
|
if (options_.has_target_size()) {
|
||||||
|
@ -203,21 +248,23 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
RET_CHECK_GT(options_.target_size().height(), 0)
|
RET_CHECK_GT(options_.target_size().height(), 0)
|
||||||
<< "Provided target height not valid.";
|
<< "Provided target height not valid.";
|
||||||
float input_aspect = frame_width_ / static_cast<float>(frame_height_);
|
float input_aspect = frame_width_ / static_cast<float>(frame_height_);
|
||||||
float target_aspect = options_.target_size().width() /
|
target_aspect_ = options_.target_size().width() /
|
||||||
static_cast<float>(options_.target_size().height());
|
static_cast<float>(options_.target_size().height());
|
||||||
min_height_value_ =
|
max_frame_value_ = std::min(input_aspect / target_aspect_,
|
||||||
(input_aspect < target_aspect) ? input_aspect / target_aspect : 1.0;
|
target_aspect_ / input_aspect);
|
||||||
}
|
}
|
||||||
last_measured_height_ = min_height_value_ * frame_height_;
|
last_measured_height_ = max_frame_value_ * frame_height_;
|
||||||
|
last_measured_x_offset_ = target_aspect_ * frame_width_;
|
||||||
last_measured_y_offset_ = frame_width_ / 2;
|
last_measured_y_offset_ = frame_width_ / 2;
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto detection_set = cc->Inputs().Tag(kDetectionSet).Get<DetectionSet>();
|
|
||||||
bool only_required_found = false;
|
bool only_required_found = false;
|
||||||
|
|
||||||
// Compute the box that contains all "is_required" detections.
|
// Compute the box that contains all "is_required" detections.
|
||||||
float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
|
float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
|
||||||
|
if (cc->Inputs().HasTag(kSalientRegions)) {
|
||||||
|
auto detection_set = cc->Inputs().Tag(kSalientRegions).Get<DetectionSet>();
|
||||||
for (const auto& region : detection_set.detections()) {
|
for (const auto& region : detection_set.detections()) {
|
||||||
if (!region.only_required()) {
|
if (!region.only_required()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -225,46 +272,64 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
only_required_found = true;
|
only_required_found = true;
|
||||||
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
|
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cc->Inputs().HasTag(kDetections)) {
|
||||||
|
auto raw_detections =
|
||||||
|
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
|
||||||
|
for (const auto& detection : raw_detections) {
|
||||||
|
only_required_found = true;
|
||||||
|
MP_RETURN_IF_ERROR(UpdateRanges(detection, &xmin, &xmax, &ymin, &ymax));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert bounds to tilt/zoom and in pixel coordinates.
|
// Convert bounds to tilt/zoom and in pixel coordinates.
|
||||||
int offset, height;
|
int offset_y, height, offset_x;
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||||
ConvertToTiltZoom(xmin, xmax, ymin, ymax, &offset, &height));
|
&offset_x, &height));
|
||||||
|
|
||||||
if (only_required_found) {
|
if (only_required_found) {
|
||||||
// An only_required detection was found.
|
// An only_required detection was found.
|
||||||
last_only_required_detection_ = cc->InputTimestamp().Microseconds();
|
last_only_required_detection_ = cc->InputTimestamp().Microseconds();
|
||||||
last_measured_height_ = height;
|
last_measured_height_ = height;
|
||||||
last_measured_y_offset_ = offset;
|
last_measured_x_offset_ = offset_x;
|
||||||
|
last_measured_y_offset_ = offset_y;
|
||||||
} else if (cc->InputTimestamp().Microseconds() -
|
} else if (cc->InputTimestamp().Microseconds() -
|
||||||
last_only_required_detection_ >=
|
last_only_required_detection_ >=
|
||||||
options_.us_before_zoomout()) {
|
options_.us_before_zoomout()) {
|
||||||
// No only_required detections found within salient regions packets arriving
|
// No only_required detections found within salient regions packets
|
||||||
// since us_before_zoomout duration.
|
// arriving since us_before_zoomout duration.
|
||||||
height = min_height_value_ * frame_height_;
|
height = max_frame_value_ * frame_height_;
|
||||||
offset = frame_height_ / 2;
|
offset_x = (target_aspect_ * height) / 2;
|
||||||
|
offset_y = frame_height_ / 2;
|
||||||
} else {
|
} else {
|
||||||
// No only_required detection found, but using the last detection due to
|
// No only_required detection found, but using the last detection due to
|
||||||
// duration_before_zoomout_us setting.
|
// duration_before_zoomout_us setting.
|
||||||
height = last_measured_height_;
|
height = last_measured_height_;
|
||||||
offset = last_measured_y_offset_;
|
offset_x = last_measured_x_offset_;
|
||||||
|
offset_y = last_measured_y_offset_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute smoothed camera paths.
|
// Compute smoothed camera paths.
|
||||||
MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
|
MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
|
||||||
height, cc->InputTimestamp().Microseconds()));
|
height, cc->InputTimestamp().Microseconds()));
|
||||||
|
MP_RETURN_IF_ERROR(path_solver_width_->AddObservation(
|
||||||
|
offset_x, cc->InputTimestamp().Microseconds()));
|
||||||
MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
|
MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
|
||||||
offset, cc->InputTimestamp().Microseconds()));
|
offset_y, cc->InputTimestamp().Microseconds()));
|
||||||
int path_size;
|
int path_height;
|
||||||
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_size));
|
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height));
|
||||||
int path_offset;
|
int path_offset_x;
|
||||||
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset));
|
MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x));
|
||||||
|
int path_offset_y;
|
||||||
|
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));
|
||||||
|
|
||||||
// Convert to top/bottom borders to remove.
|
// Convert to top/bottom borders to remove.
|
||||||
int path_top = path_offset - path_size / 2;
|
int path_top = path_offset_y - path_height / 2;
|
||||||
int path_bottom = frame_height_ - (path_offset + path_size / 2);
|
int path_bottom = frame_height_ - (path_offset_y + path_height / 2);
|
||||||
|
|
||||||
// Transmit result downstream.
|
// Transmit result downstream to scenecroppingcalculator.
|
||||||
|
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||||
std::unique_ptr<StaticFeatures> features =
|
std::unique_ptr<StaticFeatures> features =
|
||||||
absl::make_unique<StaticFeatures>();
|
absl::make_unique<StaticFeatures>();
|
||||||
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
|
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
|
||||||
|
@ -272,6 +337,18 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(kDetectedBorders)
|
.Tag(kDetectedBorders)
|
||||||
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transmit downstream to glcroppingcalculator.
|
||||||
|
if (cc->Outputs().HasTag(kCropRect)) {
|
||||||
|
auto gpu_rect = absl::make_unique<mediapipe::Rect>();
|
||||||
|
gpu_rect->set_x_center(path_offset_x);
|
||||||
|
gpu_rect->set_width(path_height * target_aspect_);
|
||||||
|
gpu_rect->set_y_center(path_offset_y);
|
||||||
|
gpu_rect->set_height(path_height);
|
||||||
|
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
|
||||||
|
Timestamp(cc->InputTimestamp()));
|
||||||
|
}
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,8 @@ message ContentZoomingCalculatorOptions {
|
||||||
optional KinematicOptions kinematic_options_zoom = 6;
|
optional KinematicOptions kinematic_options_zoom = 6;
|
||||||
// Kinematic options for tilt (y-axis reframing.)
|
// Kinematic options for tilt (y-axis reframing.)
|
||||||
optional KinematicOptions kinematic_options_tilt = 7;
|
optional KinematicOptions kinematic_options_tilt = 7;
|
||||||
|
// Kinematic options for pan (x-axis reframing.)
|
||||||
|
optional KinematicOptions kinematic_options_pan = 10;
|
||||||
// Duration (in MicroSeconds) before returning to fully zoomed out position
|
// Duration (in MicroSeconds) before returning to fully zoomed out position
|
||||||
// when no "only_required" frames are received.
|
// when no "only_required" frames are received.
|
||||||
optional int64 us_before_zoomout = 9 [default = 1000000];
|
optional int64 us_before_zoomout = 9 [default = 1000000];
|
||||||
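With the new field, pan, tilt and zoom can be tuned independently. A hedged options sketch mirroring the PanConfig test added in this change (values taken from that test; the enclosing node definition is omitted):

```
// Text-proto options sketch: pan reframes on any motion, while tilt and zoom
// require 5 degrees of motion; zoom-out still happens after one second.
constexpr char kExampleZoomingOptions[] = R"(
  [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
    kinematic_options_pan: { min_motion_to_reframe: 0.0 }
    kinematic_options_tilt: { min_motion_to_reframe: 5.0 }
    kinematic_options_zoom: { min_motion_to_reframe: 5.0 }
    us_before_zoomout: 1000000
  }
)";
```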
|
|
|
@ -16,10 +16,14 @@
|
||||||
|
|
||||||
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
||||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||||
|
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/port/benchmark.h"
|
#include "mediapipe/framework/port/benchmark.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -36,14 +40,14 @@ namespace {
|
||||||
const char kConfigA[] = R"(
|
const char kConfigA[] = R"(
|
||||||
calculator: "ContentZoomingCalculator"
|
calculator: "ContentZoomingCalculator"
|
||||||
input_stream: "VIDEO:camera_frames"
|
input_stream: "VIDEO:camera_frames"
|
||||||
input_stream: "DETECTIONS:detection_set"
|
input_stream: "SALIENT_REGIONS:detection_set"
|
||||||
output_stream: "BORDERS:borders"
|
output_stream: "BORDERS:borders"
|
||||||
)";
|
)";
|
||||||
|
|
||||||
const char kConfigB[] = R"(
|
const char kConfigB[] = R"(
|
||||||
calculator: "ContentZoomingCalculator"
|
calculator: "ContentZoomingCalculator"
|
||||||
input_stream: "VIDEO:camera_frames"
|
input_stream: "VIDEO:camera_frames"
|
||||||
input_stream: "DETECTIONS:detection_set"
|
input_stream: "SALIENT_REGIONS:detection_set"
|
||||||
output_stream: "BORDERS:borders"
|
output_stream: "BORDERS:borders"
|
||||||
options: {
|
options: {
|
||||||
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
||||||
|
@ -58,10 +62,17 @@ const char kConfigB[] = R"(
|
||||||
const char kConfigC[] = R"(
|
const char kConfigC[] = R"(
|
||||||
calculator: "ContentZoomingCalculator"
|
calculator: "ContentZoomingCalculator"
|
||||||
input_stream: "VIDEO_SIZE:size"
|
input_stream: "VIDEO_SIZE:size"
|
||||||
input_stream: "DETECTIONS:detection_set"
|
input_stream: "SALIENT_REGIONS:detection_set"
|
||||||
output_stream: "BORDERS:borders"
|
output_stream: "BORDERS:borders"
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
const char kConfigD[] = R"(
|
||||||
|
calculator: "ContentZoomingCalculator"
|
||||||
|
input_stream: "VIDEO_SIZE:size"
|
||||||
|
input_stream: "DETECTIONS:detections"
|
||||||
|
output_stream: "CROP_RECT:rect"
|
||||||
|
)";
|
||||||
|
|
||||||
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||||
int top_border, int bottom_border) {
|
int top_border, int bottom_border) {
|
||||||
ASSERT_EQ(2, static_features.border().size());
|
ASSERT_EQ(2, static_features.border().size());
|
||||||
|
@ -80,6 +91,43 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||||
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AddDetection(const cv::Rect_<float>& position, const int64 time,
|
||||||
|
CalculatorRunner* runner) {
|
||||||
|
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
||||||
|
mediapipe::Detection detection;
|
||||||
|
detection.mutable_location_data()->set_format(
|
||||||
|
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
|
||||||
|
detection.mutable_location_data()
|
||||||
|
->mutable_relative_bounding_box()
|
||||||
|
->set_height(position.height);
|
||||||
|
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
|
||||||
|
position.width);
|
||||||
|
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
|
||||||
|
position.x);
|
||||||
|
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
|
||||||
|
position.y);
|
||||||
|
detections->push_back(detection);
|
||||||
|
runner->MutableInputs()
|
||||||
|
->Tag("DETECTIONS")
|
||||||
|
.packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
|
||||||
|
runner->MutableInputs()
|
||||||
|
->Tag("VIDEO_SIZE")
|
||||||
|
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckCropRect(const int x_center, const int y_center, const int width,
|
||||||
|
const int height, const int frame_number,
|
||||||
|
const std::vector<Packet>& output_packets) {
|
||||||
|
ASSERT_GT(output_packets.size(), frame_number);
|
||||||
|
const auto& rect = output_packets[frame_number].Get<mediapipe::Rect>();
|
||||||
|
EXPECT_EQ(rect.x_center(), x_center);
|
||||||
|
EXPECT_EQ(rect.y_center(), y_center);
|
||||||
|
EXPECT_EQ(rect.width(), width);
|
||||||
|
EXPECT_EQ(rect.height(), height);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
||||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
|
||||||
|
@ -98,7 +146,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
||||||
Adopt(input_frame.release()).At(Timestamp(0)));
|
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag("SALIENT_REGIONS")
|
||||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
// Run the calculator.
|
// Run the calculator.
|
||||||
|
@ -111,6 +159,66 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
||||||
CheckBorder(static_features, 1000, 1000, 495, 395);
|
CheckBorder(static_features, 1000, 1000, 495, 395);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
|
||||||
|
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(450, 550, 111, 111, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, PanConfig) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
|
||||||
|
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
|
||||||
|
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||||
|
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(450, 550, 111, 111, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(488, 550, 111, 111, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, TiltConfig) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
|
||||||
|
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
|
||||||
|
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||||
|
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(450, 550, 111, 111, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(450, 588, 111, 111, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, ZoomConfig) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
|
||||||
|
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
|
||||||
|
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||||
|
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(450, 550, 111, 111, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(450, 550, 139, 139, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
||||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
|
||||||
|
@ -129,7 +237,7 @@ TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
||||||
Adopt(input_frame.release()).At(Timestamp(0)));
|
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag("SALIENT_REGIONS")
|
||||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
// Run the calculator.
|
// Run the calculator.
|
||||||
|
@ -166,7 +274,7 @@ TEST(ContentZoomingCalculatorTest, TwoFacesWide) {
|
||||||
Adopt(input_frame.release()).At(Timestamp(0)));
|
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag("SALIENT_REGIONS")
|
||||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
// Run the calculator.
|
// Run the calculator.
|
||||||
|
@ -191,7 +299,7 @@ TEST(ContentZoomingCalculatorTest, NoDetectionOnInit) {
|
||||||
Adopt(input_frame.release()).At(Timestamp(0)));
|
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag("SALIENT_REGIONS")
|
||||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
// Run the calculator.
|
// Run the calculator.
|
||||||
|
@ -223,7 +331,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
|
||||||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag("SALIENT_REGIONS")
|
||||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
// Run the calculator.
|
// Run the calculator.
|
||||||
|
|
|
@ -37,7 +37,7 @@ node {
|
||||||
output_stream: "TENSORS:detection_tensors"
|
output_stream: "TENSORS:detection_tensors"
|
||||||
options: {
|
options: {
|
||||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
model_path: "face_detection_front.tflite"
|
model_path: "mediapipe/models/face_detection_front.tflite"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ node {
|
||||||
output_stream: "labeled_detections"
|
output_stream: "labeled_detections"
|
||||||
options: {
|
options: {
|
||||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||||
label_map_path: "face_detection_front_labelmap.txt"
|
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1 @@
|
||||||
This directory contains example MediaPipe applications on iOS.
|
This directory contains MediaPipe example applications for iOS. Please see [Solutions](https://solutions.mediapipe.dev) for details.
|
||||||
|
|
||||||
| Use Case | Directory |
|
|
||||||
|---------------------------------------|:-----------------------------------:|
|
|
||||||
| Edge Detection on GPU | edgedetection |
|
|
||||||
| Face Detection on CPU | facedetectioncpu |
|
|
||||||
| Face Detection on GPU | facedetectiongpu |
|
|
||||||
| Object Detection on CPU | objectdetectioncpu |
|
|
||||||
| Object Detection on GPU | objectdetectiongpu |
|
|
||||||
| Hand Detection on GPU | handdetectiongpu |
|
|
||||||
| Hand Tracking on GPU | handtrackinggpu |
|
|
||||||
|
|
||||||
For instance, to build an example app for face detection on CPU, run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
|
|
||||||
```
|
|
||||||
(Note: substitute your own Xcode version for $XCODE_VERSION.)
|
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "edgedetectiongpu",
|
||||||
|
actual = "EdgeDetectionGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "EdgeDetectionGpuApp",
|
name = "EdgeDetectionGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.EdgeDetectionGpu",
|
bundle_id = "com.google.mediapipe.EdgeDetectionGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "facedetectioncpu",
|
||||||
|
actual = "FaceDetectionCpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "FaceDetectionCpuApp",
|
name = "FaceDetectionCpuApp",
|
||||||
bundle_id = "com.google.mediapipe.FaceDetectionCpu",
|
bundle_id = "com.google.mediapipe.FaceDetectionCpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "facedetectiongpu",
|
||||||
|
actual = "FaceDetectionGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "FaceDetectionGpuApp",
|
name = "FaceDetectionGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.FaceDetectionGpu",
|
bundle_id = "com.google.mediapipe.FaceDetectionGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
MIN_IOS_VERSION = "10.0"
|
MIN_IOS_VERSION = "10.0"
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "facemeshgpu",
|
||||||
|
actual = "FaceMeshGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "FaceMeshGpuApp",
|
name = "FaceMeshGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.FaceMeshGpu",
|
bundle_id = "com.google.mediapipe.FaceMeshGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "handdetectiongpu",
|
||||||
|
actual = "HandDetectionGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "HandDetectionGpuApp",
|
name = "HandDetectionGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.HandDetectionGpu",
|
bundle_id = "com.google.mediapipe.HandDetectionGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
MIN_IOS_VERSION = "10.0"
|
MIN_IOS_VERSION = "10.0"
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "handtrackinggpu",
|
||||||
|
actual = "HandTrackingGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "HandTrackingGpuApp",
|
name = "HandTrackingGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.HandTrackingGpu",
|
bundle_id = "com.google.mediapipe.HandTrackingGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
MIN_IOS_VERSION = "10.0"
|
MIN_IOS_VERSION = "10.0"
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "multihandtrackinggpu",
|
||||||
|
actual = "MultiHandTrackingGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "MultiHandTrackingGpuApp",
|
name = "MultiHandTrackingGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",
|
bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "objectdetectioncpu",
|
||||||
|
actual = "ObjectDetectionCpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "ObjectDetectionCpuApp",
|
name = "ObjectDetectionCpuApp",
|
||||||
bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
|
bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#import "AppDelegate.h"
|
#import "AppDelegate.h"
|
||||||
|
#import "ViewController.h"
|
||||||
|
|
||||||
@interface AppDelegate ()
|
@interface AppDelegate ()
|
||||||
|
|
||||||
|
@ -22,7 +23,14 @@
|
||||||
|
|
||||||
- (BOOL)application:(UIApplication *)application
|
- (BOOL)application:(UIApplication *)application
|
||||||
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
|
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
|
||||||
// Override point for customization after application launch.
|
ViewController *viewController = (ViewController *)self.window.rootViewController;
|
||||||
|
NSURL *url = [launchOptions objectForKey:UIApplicationLaunchOptionsURLKey];
|
||||||
|
// Unattended testing on Firebase is enabled by custom URL schema.
|
||||||
|
if ([url.scheme isEqualToString:@"firebase-game-loop"]) {
|
||||||
|
[viewController setSourceMode:MediaPipeDemoSourceVideo];
|
||||||
|
} else {
|
||||||
|
[viewController setSourceMode:MediaPipeDemoSourceBackCamera];
|
||||||
|
}
|
||||||
return YES;
|
return YES;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,11 @@ load(
|
||||||
"ios_application",
|
"ios_application",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "objectdetectiongpu",
|
||||||
|
actual = "ObjectDetectionGpuApp",
|
||||||
|
)
|
||||||
|
|
||||||
ios_application(
|
ios_application(
|
||||||
name = "ObjectDetectionGpuApp",
|
name = "ObjectDetectionGpuApp",
|
||||||
bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
|
bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
|
||||||
|
|
|
@ -38,5 +38,18 @@
|
||||||
<array>
|
<array>
|
||||||
<string>UIInterfaceOrientationPortrait</string>
|
<string>UIInterfaceOrientationPortrait</string>
|
||||||
</array>
|
</array>
|
||||||
|
<key>CFBundleURLTypes</key>
|
||||||
|
<array>
|
||||||
|
<dict>
|
||||||
|
<key>CFBundleURLName</key>
|
||||||
|
<string>com.google.firebase</string>
|
||||||
|
<key>CFBundleTypeRole</key>
|
||||||
|
<string>Editor</string>
|
||||||
|
<key>CFBundleURLSchemes</key>
|
||||||
|
<array>
|
||||||
|
<string>firebase-game-loop</string>
|
||||||
|
</array>
|
||||||
|
</dict>
|
||||||
|
</array>
|
||||||
</dict>
|
</dict>
|
||||||
</plist>
|
</plist>
|
||||||
|
|
|
@ -14,6 +14,11 @@
|
||||||
|
|
||||||
#import <UIKit/UIKit.h>
|
#import <UIKit/UIKit.h>
|
||||||
|
|
||||||
@interface ViewController : UIViewController
|
typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) {
|
||||||
|
MediaPipeDemoSourceBackCamera,
|
||||||
|
MediaPipeDemoSourceVideo
|
||||||
|
};
|
||||||
|
|
||||||
|
@interface ViewController : UIViewController
|
||||||
|
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode;
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#import "mediapipe/objc/MPPGraph.h"
|
#import "mediapipe/objc/MPPGraph.h"
|
||||||
#import "mediapipe/objc/MPPCameraInputSource.h"
|
#import "mediapipe/objc/MPPCameraInputSource.h"
|
||||||
#import "mediapipe/objc/MPPLayerRenderer.h"
|
#import "mediapipe/objc/MPPLayerRenderer.h"
|
||||||
|
#import "mediapipe/objc/MPPPlayerInputSource.h"
|
||||||
|
|
||||||
static NSString* const kGraphName = @"mobile_gpu";
|
static NSString* const kGraphName = @"mobile_gpu";
|
||||||
|
|
||||||
|
@ -35,6 +36,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
||||||
@implementation ViewController {
|
@implementation ViewController {
|
||||||
/// Handles camera access via AVCaptureSession library.
|
/// Handles camera access via AVCaptureSession library.
|
||||||
MPPCameraInputSource* _cameraSource;
|
MPPCameraInputSource* _cameraSource;
|
||||||
|
MPPPlayerInputSource* _videoSource;
|
||||||
|
MediaPipeDemoSourceMode _sourceMode;
|
||||||
|
|
||||||
/// Inform the user when camera is unavailable.
|
/// Inform the user when camera is unavailable.
|
||||||
IBOutlet UILabel* _noCameraLabel;
|
IBOutlet UILabel* _noCameraLabel;
|
||||||
|
@ -47,6 +50,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
||||||
dispatch_queue_t _videoQueue;
|
dispatch_queue_t _videoQueue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode {
|
||||||
|
_sourceMode = mode;
|
||||||
|
}
|
||||||
|
|
||||||
#pragma mark - Cleanup methods
|
#pragma mark - Cleanup methods
|
||||||
|
|
||||||
- (void)dealloc {
|
- (void)dealloc {
|
||||||
|
@ -97,13 +104,6 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
||||||
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0);
|
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0);
|
||||||
_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute);
|
_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute);
|
||||||
|
|
||||||
_cameraSource = [[MPPCameraInputSource alloc] init];
|
|
||||||
[_cameraSource setDelegate:self queue:_videoQueue];
|
|
||||||
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
|
|
||||||
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
|
|
||||||
// The frame's native format is rotated with respect to the portrait orientation.
|
|
||||||
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
|
|
||||||
|
|
||||||
self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
|
self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
|
||||||
self.mediapipeGraph.delegate = self;
|
self.mediapipeGraph.delegate = self;
|
||||||
// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
|
// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
|
||||||
|
@ -119,27 +119,43 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
||||||
- (void)viewWillAppear:(BOOL)animated {
|
- (void)viewWillAppear:(BOOL)animated {
|
||||||
[super viewWillAppear:animated];
|
[super viewWillAppear:animated];
|
||||||
|
|
||||||
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
|
|
||||||
if (granted) {
|
|
||||||
[self startGraphAndCamera];
|
|
||||||
dispatch_async(dispatch_get_main_queue(), ^{
|
|
||||||
_noCameraLabel.hidden = YES;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
- (void)startGraphAndCamera {
|
|
||||||
// Start running self.mediapipeGraph.
|
// Start running self.mediapipeGraph.
|
||||||
NSError* error;
|
NSError* error;
|
||||||
if (![self.mediapipeGraph startWithError:&error]) {
|
if (![self.mediapipeGraph startWithError:&error]) {
|
||||||
NSLog(@"Failed to start graph: %@", error);
|
NSLog(@"Failed to start graph: %@", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start fetching frames from the camera.
|
switch (_sourceMode) {
|
||||||
|
case MediaPipeDemoSourceVideo: {
|
||||||
|
AVAsset* video =
|
||||||
|
[AVAsset assetWithURL:[[NSBundle mainBundle] URLForResource:@"object_detection"
|
||||||
|
withExtension:@"mov"]];
|
||||||
|
_videoSource = [[MPPPlayerInputSource alloc] initWithAVAsset:video];
|
||||||
|
[_videoSource setDelegate:self queue:_videoQueue];
|
||||||
|
dispatch_async(_videoQueue, ^{
|
||||||
|
[_videoSource start];
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case MediaPipeDemoSourceBackCamera:
|
||||||
|
_cameraSource = [[MPPCameraInputSource alloc] init];
|
||||||
|
[_cameraSource setDelegate:self queue:_videoQueue];
|
||||||
|
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
|
||||||
|
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
|
||||||
|
// The frame's native format is rotated with respect to the portrait orientation.
|
||||||
|
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
|
||||||
|
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
|
||||||
|
if (granted) {
|
||||||
dispatch_async(_videoQueue, ^{
|
dispatch_async(_videoQueue, ^{
|
||||||
[_cameraSource start];
|
[_cameraSource start];
|
||||||
});
|
});
|
||||||
|
dispatch_async(dispatch_get_main_queue(), ^{
|
||||||
|
_noCameraLabel.hidden = YES;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark - MPPGraphDelegate methods
|
#pragma mark - MPPGraphDelegate methods
|
||||||
|
@ -164,7 +180,7 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
||||||
- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
|
- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
|
||||||
timestamp:(CMTime)timestamp
|
timestamp:(CMTime)timestamp
|
||||||
fromSource:(MPPInputSource*)source {
|
fromSource:(MPPInputSource*)source {
|
||||||
if (source != _cameraSource) {
|
if (source != _cameraSource && source != _videoSource) {
|
||||||
NSLog(@"Unknown source: %@", source);
|
NSLog(@"Unknown source: %@", source);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ exports_files([
mediapipe_proto_library(
    name = "calculator_proto",
    srcs = ["calculator.proto"],
    visibility = [":mediapipe_internal"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:mediapipe_options_proto",
@ -68,7 +68,7 @@ mediapipe_proto_library(
mediapipe_proto_library(
    name = "calculator_profile_proto",
    srcs = ["calculator_profile.proto"],
    visibility = [":mediapipe_internal"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
@ -830,6 +830,8 @@ cc_library(
        ":port",
        ":timestamp",
        ":type_map",
        "//mediapipe/framework/deps:no_destructor",
        "//mediapipe/framework/deps:registration",
        "//mediapipe/framework/port:core_proto",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:logging",
@ -1524,6 +1526,21 @@ cc_test(
    ],
)

cc_test(
    name = "packet_registration_test",
    size = "small",
    srcs = ["packet_registration_test.cc"],
    deps = [
        ":calculator_framework",
        ":packet",
        ":packet_test_cc_proto",
        ":type_map",
        "//mediapipe/framework/port:core_proto",
        "//mediapipe/framework/port:gtest_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "packet_generator_test",
    size = "small",
@ -115,6 +115,9 @@ class CalculatorContract {
  // When true, Process is called for every new timestamp bound, with or without
  // new packets. A call to Process with only an input timestamp bound is
  // normally used to compute a new output timestamp bound.
  // NOTE: Also, when true, Process is called when input streams become done,
  // which means, Process needs to handle input streams in "done" state.
  // (Usually, by closing calculators' outputs where and when appropriate.)
  void SetProcessTimestampBounds(bool process_timestamps) {
    process_timestamps_ = process_timestamps;
  }
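A calculator that opts in must then tolerate Process() calls that carry only a
timestamp bound or an input stream in "done" state. A minimal sketch of such a
calculator (BoundForwardingCalculator and its int streams are hypothetical,
not part of this change):

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

class BoundForwardingCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    // Request Process() calls for new timestamp bounds and for done streams,
    // not only for new packets.
    cc->SetProcessTimestampBounds(true);
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (cc->Inputs().Index(0).IsEmpty()) {
      // Only a bound (or a done stream) arrived: advance the output bound.
      cc->Outputs().Index(0).SetNextTimestampBound(
          cc->InputTimestamp().NextAllowedInStream());
    } else {
      cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    }
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(BoundForwardingCalculator);

}  // namespace mediapipe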
@ -91,6 +91,9 @@ typedef ::mediapipe::StatusOr<OutputStreamPoller> StatusOrPoller;
//       {{"video_id", mediapipe::MakePacket<std::string>("Ex-uGhDzue4")}}));
//   // See mediapipe/framework/graph_runner.h for an interface
//   // to insert and extract packets from a graph as it runs.
//   // Once it is done using the graph, close its streams and wait till done.
//   MP_RETURN_IF_ERROR(graph->CloseAllInputStreams());
//   MP_RETURN_IF_ERROR(graph->WaitUntilDone());
class CalculatorGraph {
 public:
  // Defines possible modes for adding a packet to a graph input stream.
@ -157,8 +160,9 @@ class CalculatorGraph {
      std::function<::mediapipe::Status(const Packet&)> packet_callback);

  // Adds an OutputStreamPoller for a stream. This provides a synchronous,
  // polling API for accessing a stream's output. For asynchronous output, use
  // ObserveOutputStream. See also the helpers in tool/sink.h.
  // polling API for accessing a stream's output. Should only be called before
  // Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
  // also the helpers in tool/sink.h.
  StatusOrPoller AddOutputStreamPoller(const std::string& stream_name);

  // Gets output side packet by name after the graph is done. However, base
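Read together with the usage comment above: pollers are created first, then the
graph is started, fed, closed, and drained. A sketch only (the graph config and
stream names are illustrative, not from this commit):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"

::mediapipe::Status RunSketch() {
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"(
        input_stream: "in"
        output_stream: "out"
        node { calculator: "PassThroughCalculator"
               input_stream: "in" output_stream: "out" })")));
  // Must be requested before StartRun(), per the updated comment.
  ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("out"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "in", mediapipe::MakePacket<int>(7).At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseAllInputStreams());
  mediapipe::Packet packet;
  while (poller.Next(&packet)) {
    LOG(INFO) << packet.Timestamp().Value() << ": " << packet.Get<int>();
  }
  return graph.WaitUntilDone();
}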
@ -300,6 +304,13 @@ class CalculatorGraph {
  void RecordError(const ::mediapipe::Status& error)
      ABSL_LOCKS_EXCLUDED(error_mutex_);

  // Combines errors into a status. Returns true if the vector of errors is
  // non-empty.
  bool GetCombinedErrors(const std::string& error_prefix,
                         ::mediapipe::Status* error_status);
  // Convenience overload which specifies a default error prefix.
  bool GetCombinedErrors(::mediapipe::Status* error_status);

  // Returns the maximum input stream queue size.
  int GetMaxInputStreamQueueSize();

@ -501,13 +512,6 @@ class CalculatorGraph {
  void CleanupAfterRun(::mediapipe::Status* status)
      ABSL_LOCKS_EXCLUDED(error_mutex_);

  // Combines errors into a status. Returns true if the vector of errors is
  // non-empty.
  bool GetCombinedErrors(const std::string& error_prefix,
                         ::mediapipe::Status* error_status);
  // Convenience overload which specifies a default error prefix.
  bool GetCombinedErrors(::mediapipe::Status* error_status);

  // Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one
  // is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN).
  // current_run_side_packets_ must be set before this function is called.
@ -459,7 +459,8 @@ class Vector3
  int LargestAbsComponent() const {
    Vector3 temp = Abs();
    return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2
                             : temp[1] > temp[2] ? 1 : 2;
           : temp[1] > temp[2] ? 1
                               : 2;
  }

  // return the index of the smallest, median ,largest component of the vector
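The reformatted ternary above only changes layout; the selection logic is
unchanged. An equivalent standalone version for reference (plain C++, not the
MediaPipe Vector3 class itself):

#include <array>
#include <cassert>
#include <cmath>

// Index of the component with the largest absolute value.
int LargestAbsComponent(const std::array<double, 3>& v) {
  const double x = std::abs(v[0]), y = std::abs(v[1]), z = std::abs(v[2]);
  return x > y ? (x > z ? 0 : 2)
               : (y > z ? 1 : 2);
}

int main() {
  assert(LargestAbsComponent({1.0, -5.0, 2.0}) == 1);
  assert(LargestAbsComponent({0.5, 0.25, -3.0}) == 2);
}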
@ -155,7 +155,7 @@ class InputStreamHandler {
  // max number of invocations that are allowed to be scheduled is reached.
  // Returns true if at least one invocation has been scheduled.
  // The latest minimum timestamp bound of the input streams is returned in
  // *input_bound iff the latest readiness of the node is kNotReady when the
  // *input_bound if the latest readiness of the node is kNotReady when the
  // function returns. During batching, this value will be equal to the
  // timestamp of the first set of inputs in the batch. In other cases,
  // Timestamp::Unset() is returned.
@ -66,6 +66,20 @@ class LegacyCalculatorSupport {
  };
};

// We only declare this variable for two specializations of the template because
// it is only meant to be used for these two types.
// Note that, since these variables are members of specific template
// _specializations_, they are not themselves templates, and therefore their
// definitions must be in the .cc file. However, a declaration still needs to be
// included in the header, or some compilers will assume they have no
// definition.
template <>
thread_local CalculatorContext*
    LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
template <>
thread_local CalculatorContract*
    LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;

}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_
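The rule the new comment relies on, illustrated outside MediaPipe (hypothetical
Counter type): a static data member of an explicit specialization is an
ordinary variable, so the header may only declare it, while the member of the
primary template can still be defined inline.

// counter.h
template <typename T>
struct Counter {
  static int count;
};
template <typename T>
int Counter<T>::count = 0;  // fine in the header: member of a class template

template <>
struct Counter<int> {
  static int count;  // declaration only; the definition must live in one .cc
};

// counter.cc
// int Counter<int>::count = 0;  // the single out-of-line definition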
@ -51,6 +51,18 @@ const HolderBase* GetHolder(const Packet& packet) {
  return packet.holder_.get();
}

::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
    const std::string& type_name, const std::string& serialized) {
  ASSIGN_OR_RETURN(
      auto message_holder,
      packet_internal::MessageHolderRegistry::CreateByName(type_name));
  auto* message =
      const_cast<proto_ns::MessageLite*>(message_holder->GetProtoMessageLite());
  RET_CHECK_NE(message, nullptr);
  RET_CHECK(message->ParseFromString(serialized));
  return packet_internal::Create(message_holder.release());
}

}  // namespace packet_internal

Packet Packet::At(class Timestamp timestamp) const& {
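The intended call pattern for the new helper, mirroring the
PacketFromSerializedProto test added further down (assumes SimpleProto from
packet_test.proto and that "mediapipe.SimpleProto" is already in the
MessageHolderRegistry, which the tests below arrange):

mediapipe::SimpleProto original;
original.add_value("foo");
std::string serialized = original.SerializeAsString();

::mediapipe::StatusOr<Packet> maybe_packet =
    packet_internal::PacketFromDynamicProto("mediapipe.SimpleProto",
                                            serialized);
if (maybe_packet.ok()) {
  Packet packet = maybe_packet.ValueOrDie();
  CHECK_EQ(packet.Get<mediapipe::SimpleProto>().value(0), "foo");
}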
@ -27,6 +27,8 @@
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/deps/no_destructor.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
@ -51,6 +53,8 @@ Packet Create(HolderBase* holder, Timestamp timestamp);
Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp);
const HolderBase* GetHolder(const Packet& packet);
const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet);
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
    const std::string& type_name, const std::string& serialized);
}  // namespace packet_internal

// A generic container class which can hold data of any type. The type of
@ -355,112 +359,11 @@ class HolderBase {
  // Downcasts this to Holder<T>. Returns nullptr if deserialization
  // failed or if the requested type is not what is stored.
  template <typename T>
  inline Holder<T>* As(
      typename std::enable_if<
          (!std::is_base_of<proto_ns::MessageLite, T>::value &&
           !std::is_base_of<proto_ns::Message, T>::value) ||
          (std::is_same<proto_ns::MessageLite, T>::value ||
           std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
      return static_cast<Holder<T>*>(this);
    }
    // Does not hold a T.
    return nullptr;
  }

  // For proto Message/MessageLite subclasses.
  // When holder data is a concrete proto, the method downcasts this to
  // Holder<T> if the requested type is what is stored.
  // When holder data is a generic proto Message/MessageLite and a concrete
  // proto type T is requested, the method will downcast the HolderBase to
  // Holder<T> if the proto data is an instance of T.
  template <typename T>
  inline Holder<T>* As(
      typename std::enable_if<
          (std::is_base_of<proto_ns::MessageLite, T>::value ||
           std::is_base_of<proto_ns::Message, T>::value) &&
          (!std::is_same<proto_ns::MessageLite, T>::value &&
           !std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
    // Holder data is an instance of subclass type T.
    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
      return static_cast<Holder<T>*>(this);
    }

    // Holder data is a generic proto Message/MessageLite and a subclass type T
    // is requested.
    if (HolderIsOfType<Holder<proto_ns::Message>>() ||
        HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
        HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
        HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
      // TODO: Holder<proto_ns::Message/MessageLite> cannot be
      // legally downcast to Holder<T>, even though that downcast works in
      // practice. Need to propose a better way to do the downcast.
      Holder<T>* holder = static_cast<Holder<T>*>(this);
      T tmp;
      VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
              << " vs requested proto type: " << tmp.GetTypeName();
      if (tmp.GetTypeName() == holder->data().GetTypeName()) {
        return holder;
      }
    }

    // Does not hold a T.
    return nullptr;
  }
  Holder<T>* As();

  // Same as non-const As() function.
  template <typename T>
  inline const Holder<T>* As(
      typename std::enable_if<
          (!std::is_base_of<proto_ns::MessageLite, T>::value &&
           !std::is_base_of<proto_ns::Message, T>::value) ||
          (std::is_same<proto_ns::MessageLite, T>::value ||
           std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
      return static_cast<const Holder<T>*>(this);
    }
    // Does not hold a T.
    return nullptr;
  }

  // For proto Message/MessageLite subclasses.
  // When holder data is a concrete proto, the method downcasts this to
  // Holder<T> if the requested type is what is stored.
  // When holder data is a generic proto Message/MessageLite and a concrete
  // proto type T is requested, the method will downcast the HolderBase to
  // Holder<T> if the proto data is an instance of T.
  template <typename T>
  inline const Holder<T>* As(
      typename std::enable_if<
          (std::is_base_of<proto_ns::MessageLite, T>::value ||
           std::is_base_of<proto_ns::Message, T>::value) &&
          (!std::is_same<proto_ns::MessageLite, T>::value &&
           !std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
      return static_cast<const Holder<T>*>(this);
    }

    // Holder data is a generic proto Message/MessageLite and a subclass type T
    // is requested.
    if (HolderIsOfType<Holder<proto_ns::Message>>() ||
        HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
        HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
        HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
      // TODO: Holder<proto_ns::Message/MessageLite> cannot be
      // legally downcast to Holder<T>, even though that downcast works in
      // practice. Need to propose a better way to do the downcast.
      Holder<T>* holder = static_cast<const Holder<T>*>(this);
      T tmp;
      VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
              << " vs requested proto type: " << tmp.GetTypeName();
      if (tmp.GetTypeName() == holder->data().GetTypeName()) {
        return holder;
      }
    }

    // Does not hold a T.
    return nullptr;
  }
  const Holder<T>* As() const;

  // Returns the pointer to MessageLite type for the data in holder, if
  // underlying object is protocol buffer type, otherwise, nullptr is returned.
@ -520,12 +423,68 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data,
  return result;
}

// This registry is used to create Holders of the right concrete C++ type given
// a proto type std::string (which is used as the registration key).
class MessageHolderRegistry
    : public GlobalFactoryRegistry<std::unique_ptr<HolderBase>> {};

template <typename T>
struct is_concrete_proto_t
    : public std::integral_constant<
          bool, std::is_base_of<proto_ns::MessageLite, T>{} &&
                    !std::is_same<proto_ns::MessageLite, T>{} &&
                    !std::is_same<proto_ns::Message, T>{}> {};

// Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T>
struct MessageRegistrationImpl {
  static NoDestructor<mediapipe::RegistrationToken> registration;
};

// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
    MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
        T{}.GetTypeName(), [] { return absl::make_unique<Holder<T>>(new T); }));

// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>
struct HolderSupport {
  static void EnsureStaticInit() {}
};

// This template ensures that, for each concrete MessageLite subclass that is
// stored in a Packet, we register a function that allows us to create a
// Holder with the correct payload type from the proto's type name.
template <typename T>
struct HolderSupport<T,
                     typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
  // We must use std::remove_cv to ensure we don't try to register Foo twice if
  // there are Holder<Foo> and Holder<const Foo>. TODO: lift this
  // up to Holder?
  using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
  // For the registration static member to be instantiated, it needs to be
  // referenced in a context that requires the definition to exist (see ISO/IEC
  // C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
  // We need two different call-sites to cover proto types for which packets
  // are only ever created (i.e. the protos are only produced by calculators)
  // and proto types for which packets are only ever consumed (i.e. the protos
  // are only consumed by calculators).
  static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
};
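The registration trick above is a general C++ pattern: a static data member of
a class template performs a side effect when its definition is instantiated,
and instantiation is forced by referencing the member. A self-contained sketch
of the same idea (illustrative names, not MediaPipe code):

#include <functional>
#include <iostream>
#include <map>
#include <string>

std::map<std::string, std::function<void()>>& Registry() {
  static auto* r = new std::map<std::string, std::function<void()>>();
  return *r;
}

template <typename T>
struct AutoRegister {
  static bool registered;
};

// Instantiated once per T that odr-uses `registered`; the initializer runs the
// registration side effect before main().
template <typename T>
bool AutoRegister<T>::registered = [] {
  Registry()[T::Name()] = [] { std::cout << "created " << T::Name() << "\n"; };
  return true;
}();

struct Foo {
  static std::string Name() { return "Foo"; }
  // Referencing the static member forces its definition to be instantiated,
  // which is what actually performs the registration (cf. EnsureStaticInit).
  Foo() { (void)AutoRegister<Foo>::registered; }
};

int main() {
  Foo foo;              // first use of Foo triggers registration
  Registry()["Foo"]();  // prints "created Foo"
}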
template <typename T>
class Holder : public HolderBase {
 public:
  explicit Holder(const T* ptr) : ptr_(ptr) { SetHolderTypeId<Holder>(); }
  explicit Holder(const T* ptr) : ptr_(ptr) {
    HolderSupport<T>::EnsureStaticInit();
    SetHolderTypeId<Holder>();
  }
  ~Holder() override { delete_helper(); }
  const T& data() const { return *ptr_; }
  const T& data() const {
    HolderSupport<T>::EnsureStaticInit();
    return *ptr_;
  }
  size_t GetTypeId() const final { return tool::GetTypeHash<T>(); }
  // Releases the underlying data pointer and transfers the ownership to a
  // unique pointer.
@ -622,6 +581,24 @@ class ForeignHolder : public Holder<T> {
  }
};

template <typename T>
Holder<T>* HolderBase::As() {
  if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
    return static_cast<Holder<T>*>(this);
  }
  // Does not hold a T.
  return nullptr;
}

template <typename T>
const Holder<T>* HolderBase::As() const {
  if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
    return static_cast<const Holder<T>*>(this);
  }
  // Does not hold a T.
  return nullptr;
}

}  // namespace packet_internal

inline Packet::Packet(const Packet& packet)
57  mediapipe/framework/packet_registration_test.cc (new file)
@ -0,0 +1,57 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {
namespace {

namespace test_ns {

class TestSinkCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
    cc->Outputs().Tag("OUT").Set<int>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
    cc->Outputs().Tag("OUT").AddPacket(
        MakePacket<int>(x).At(cc->InputTimestamp()));
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(::mediapipe::test_ns::TestSinkCalculator);

}  // namespace test_ns

TEST(PacketTest, InputTypeRegistration) {
  using testing::Contains;
  ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
            "mediapipe.InputOnlyProto");
  EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
              Contains("mediapipe.InputOnlyProto"));
}

}  // namespace
}  // namespace mediapipe
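With "mediapipe.InputOnlyProto" guaranteed to be registered, the same type name
can be used to rebuild a typed packet from serialized bytes. A short sketch
tying this test to the PacketFromDynamicProto helper (not part of the new test
file):

mediapipe::InputOnlyProto input;
input.set_x(42);
::mediapipe::StatusOr<Packet> maybe_packet =
    packet_internal::PacketFromDynamicProto("mediapipe.InputOnlyProto",
                                            input.SerializeAsString());
MP_ASSERT_OK(maybe_packet);
EXPECT_EQ(maybe_packet.ValueOrDie().Get<mediapipe::InputOnlyProto>().x(), 42);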
@ -174,54 +174,13 @@ TEST(PacketTest, ReturnGenericProtobufMessage) {
                   .x(0));
}

TEST(PacketTest, ReturnProtobufMessageSubType) {
  std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
      new ::mediapipe::PacketTestProto);
  proto_ptr->add_x(123);
  Packet packet = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
  EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
  EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}

TEST(PacketTest, TryWrongProtobufMessageSubType) {
  // Packet of PacketTestProto.
  std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
      new ::mediapipe::PacketTestProto);
  proto_ptr->add_x(123);
  Packet packet = Adopt(proto_ptr.release());
  EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
  EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());

  // Packet of proto_ns::Message.
  proto_ptr.reset(new ::mediapipe::PacketTestProto);
  proto_ptr->add_x(456);
  Packet packet2 = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
  EXPECT_FALSE(packet2.ValidateAsType<::mediapipe::SimpleProto>().ok());
  EXPECT_TRUE(packet2.ValidateAsType<::mediapipe::PacketTestProto>().ok());
  EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
}

TEST(PacketTest, ReturnProtobufMessageLiteSubType) {
  std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
      new ::mediapipe::PacketTestProto);
  proto_ptr->add_x(123);
  Packet packet =
      Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
  EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
  EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}

TEST(PacketTest, TryWrongProtobufMessageLiteSubType) {
  // Packet of PacketTestProto.
  std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
      new ::mediapipe::PacketTestProto);
  // Packet of proto_ns::MessageLite.
  proto_ptr->add_x(456);
  Packet packet =
      Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
  EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
  EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
  EXPECT_EQ(456, packet.Get<::mediapipe::PacketTestProto>().x(0));
}

TEST(PacketTest, GetProtoBase) {

@ -505,5 +464,26 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) {
  EXPECT_TRUE(packet2.IsEmpty());
}

TEST(PacketTest, MessageHolderRegistration) {
  using testing::Contains;
  Packet packet = MakePacket<mediapipe::SimpleProto>();
  ASSERT_EQ(mediapipe::SimpleProto{}.GetTypeName(), "mediapipe.SimpleProto");
  EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
              Contains("mediapipe.SimpleProto"));
}

TEST(PacketTest, PacketFromSerializedProto) {
  mediapipe::SimpleProto original;
  original.add_value("foo");
  std::string serialized = original.SerializeAsString();

  StatusOr<Packet> maybe_packet = packet_internal::PacketFromDynamicProto(
      "mediapipe.SimpleProto", serialized);
  MP_ASSERT_OK(maybe_packet);
  Packet packet = maybe_packet.ValueOrDie();
  MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>());
  EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
}

}  // namespace
}  // namespace mediapipe
@ -39,3 +39,9 @@ message SerializationProxyProto {
  repeated float float_value = 2;
  repeated string string_value = 3;
}

// This proto should be used only as an input to a calculator, to verify that
// case is covered.
message InputOnlyProto {
  optional int32 x = 1;
}
@ -46,7 +46,7 @@
// but may or may not still be able to run other OpenGL code.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
    (defined(__APPLE__) || defined(__EMSCRIPTEN__) || \
     defined(MEDIAPIPE_DISABLE_GPU))
     defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER)
#define MEDIAPIPE_DISABLE_GL_COMPUTE
#endif
@ -143,8 +143,8 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
|
||||||
{{MakePacket<std::string>("goodbye").At(start_timestamp_)}});
|
{{MakePacket<std::string>("goodbye").At(start_timestamp_)}});
|
||||||
|
|
||||||
// Validate the GraphTrace data.
|
// Validate the GraphTrace data.
|
||||||
EXPECT_THAT(GetTrace(),
|
EXPECT_THAT(
|
||||||
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||||
base_time: 1608911100000000
|
base_time: 1608911100000000
|
||||||
base_timestamp: 1608911100000000
|
base_timestamp: 1608911100000000
|
||||||
stream_name: ""
|
stream_name: ""
|
||||||
|
@ -163,7 +163,7 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
|
||||||
stream_id: 1
|
stream_id: 1
|
||||||
event_data: 1
|
event_data: 1
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
|
||||||
}
|
}
|
||||||
)")));
|
)")));
|
||||||
}
|
}
|
||||||
|
@ -205,18 +205,27 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
||||||
LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
|
LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
|
||||||
{{MakePacket<std::string>("out").At(start_timestamp_)}});
|
{{MakePacket<std::string>("out").At(start_timestamp_)}});
|
||||||
curr_time += absl::Microseconds(2000);
|
curr_time += absl::Microseconds(2000);
|
||||||
ClearCalculatorContext("PCalculator_3");
|
|
||||||
LogInputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
|
// Note: the packet data ID is based on the packet's payload address, which
|
||||||
|
// means the same ID can be reused if data is allocated in the same location
|
||||||
|
// as a previously expired packet (b/160212191). This means the generated
|
||||||
|
// trace can change depending on the allocator. To keep results stable, we
|
||||||
|
// must keep the packets used in this test alive until the end. Each
|
||||||
|
// TestContextBuilder happens to keep a reference to all packets for the last
|
||||||
|
// context, so for now we just create a separate TestContextBuilder instead of
|
||||||
|
// clearing it. TODO: revise this test.
|
||||||
|
SetUpCalculatorContext("PCalculator_3a", /*node_id=*/2, {"up_2"}, {"down_2"});
|
||||||
|
LogInputPackets("PCalculator_3a", GraphTrace::PROCESS, curr_time,
|
||||||
{MakePacket<std::string>("pup").At(start_timestamp_ + 5)});
|
{MakePacket<std::string>("pup").At(start_timestamp_ + 5)});
|
||||||
curr_time += absl::Microseconds(20000);
|
curr_time += absl::Microseconds(20000);
|
||||||
LogOutputPackets(
|
LogOutputPackets(
|
||||||
"PCalculator_3", GraphTrace::PROCESS, curr_time,
|
"PCalculator_3a", GraphTrace::PROCESS, curr_time,
|
||||||
{{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}});
|
{{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}});
|
||||||
curr_time += absl::Microseconds(1000);
|
curr_time += absl::Microseconds(1000);
|
||||||
|
|
||||||
// Validate the GraphTrace data.
|
// Validate the GraphTrace data.
|
||||||
EXPECT_THAT(GetTrace(),
|
EXPECT_THAT(
|
||||||
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||||
base_time: 1608911100000000
|
base_time: 1608911100000000
|
||||||
base_timestamp: 1608911100000000
|
base_timestamp: 1608911100000000
|
||||||
stream_name: ""
|
stream_name: ""
|
||||||
|
@ -238,9 +247,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
||||||
stream_id: 1
|
stream_id: 1
|
||||||
event_data: 1
|
event_data: 1
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
|
||||||
output_trace { packet_timestamp: 0 stream_id: 3 }
|
output_trace { packet_timestamp: 0 stream_id: 3 event_data: 3 }
|
||||||
output_trace { packet_timestamp: 5 stream_id: 3 }
|
output_trace { packet_timestamp: 5 stream_id: 3 event_data: 4 }
|
||||||
}
|
}
|
||||||
calculator_trace {
|
calculator_trace {
|
||||||
node_id: 1
|
node_id: 1
|
||||||
|
@ -254,9 +263,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
||||||
finish_time: 11000
|
finish_time: 11000
|
||||||
packet_timestamp: 0
|
packet_timestamp: 0
|
||||||
stream_id: 2
|
stream_id: 2
|
||||||
event_data: 2
|
event_data: 5
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 0 stream_id: 4 }
|
output_trace { packet_timestamp: 0 stream_id: 4 event_data: 6 }
|
||||||
}
|
}
|
||||||
calculator_trace {
|
calculator_trace {
|
||||||
node_id: 2
|
node_id: 2
|
||||||
|
@ -270,9 +279,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
||||||
finish_time: 16000
|
finish_time: 16000
|
||||||
packet_timestamp: 0
|
packet_timestamp: 0
|
||||||
stream_id: 3
|
stream_id: 3
|
||||||
event_data: 3
|
event_data: 7
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 0 stream_id: 5 }
|
output_trace { packet_timestamp: 0 stream_id: 5 event_data: 8 }
|
||||||
}
|
}
|
||||||
calculator_trace {
|
calculator_trace {
|
||||||
node_id: 2
|
node_id: 2
|
||||||
|
@ -286,9 +295,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
||||||
finish_time: 38000
|
finish_time: 38000
|
||||||
packet_timestamp: 5
|
packet_timestamp: 5
|
||||||
stream_id: 3
|
stream_id: 3
|
||||||
event_data: 4
|
event_data: 9
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 5 stream_id: 5 }
|
output_trace { packet_timestamp: 5 stream_id: 5 event_data: 10 }
|
||||||
}
|
}
|
||||||
)")));
|
)")));
|
||||||
|
|
||||||
|
@ -1275,7 +1284,9 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
|
||||||
GraphTrace trace_1;
|
GraphTrace trace_1;
|
||||||
builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(),
|
builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(),
|
||||||
&trace_1);
|
&trace_1);
|
||||||
EXPECT_THAT(trace_1, EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
|
EXPECT_THAT(
|
||||||
|
trace_1,
|
||||||
|
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
|
||||||
R"(
|
R"(
|
||||||
base_time: 1100
|
base_time: 1100
|
||||||
base_timestamp: 1000
|
base_timestamp: 1000
|
||||||
|
@ -1294,7 +1305,7 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
|
||||||
stream_id: 1
|
stream_id: 1
|
||||||
event_data: 0
|
event_data: 0
|
||||||
}
|
}
|
||||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 0 }
|
||||||
thread_id: 0
|
thread_id: 0
|
||||||
}
|
}
|
||||||
calculator_trace {
|
calculator_trace {
|
||||||
|
|
|
@ -330,13 +330,12 @@ class TraceBuilder::Impl {
|
||||||
if (trace_event_registry_[event->event_type].is_stream_event()) {
|
if (trace_event_registry_[event->event_type].is_stream_event()) {
|
||||||
auto stream_trace = event->is_finish ? result->add_output_trace()
|
auto stream_trace = event->is_finish ? result->add_output_trace()
|
||||||
: result->add_input_trace();
|
: result->add_input_trace();
|
||||||
if (event->is_finish) {
|
|
||||||
// Log only the packet id for each output event.
|
|
||||||
stream_trace->set_stream_id(stream_id_map_[event->stream_id]);
|
|
||||||
stream_trace->set_packet_timestamp(LogTimestamp(event->packet_ts));
|
|
||||||
} else {
|
|
||||||
// Log the full stream trace for each input event.
|
|
||||||
BuildStreamTrace(*event, stream_trace);
|
BuildStreamTrace(*event, stream_trace);
|
||||||
|
if (!event->is_finish) {
|
||||||
|
// Note: is_finish is true for output events, false for input events.
|
||||||
|
// For input events, we log some additional timing information. The
|
||||||
|
// finish_time is the start_time of this Process call, the start_time
|
||||||
|
// is the finish_time of the Process call that output the packet.
|
||||||
stream_trace->set_finish_time(LogTime(event->event_time));
|
stream_trace->set_finish_time(LogTime(event->event_time));
|
||||||
const TraceEvent* output_event = FindOutputEvent(*event);
|
const TraceEvent* output_event = FindOutputEvent(*event);
|
||||||
if (output_event) {
|
if (output_event) {
|
||||||
|
|
|
@ -116,10 +116,19 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
    CHECK_EQ(stream_ts, Timestamp::Done());
    if (ProcessTimestampBounds()) {
      // With kReadyForClose, the timestamp-bound Done is returned.
      // This bound is processed using the preceding input-timestamp.
      // TODO: Make all InputStreamHandlers process Done() like this.
      ready_timestamps_[i] = stream_ts.PreviousAllowedInStream();
      input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
      static const Timestamp kDonePrecedingTimestamp =
          Timestamp::Done().PreviousAllowedInStream();
      if (prev_ts < kDonePrecedingTimestamp) {
        // When kReadyForClose is received for the first time for a sync set,
        // it is processed using the timestamp preceding Done() to indicate
        // input stream is done, but still needs to be processed.
        min_bound = std::min(min_bound, kDonePrecedingTimestamp);
        input_timestamp = std::min(input_timestamp, kDonePrecedingTimestamp);
        ready_timestamps_[i] = kDonePrecedingTimestamp;
      } else {
        ready_timestamps_[i] = Timestamp::Done();
      }
    } else if (prev_ts < Timestamp::Done()) {
      stream_became_done = true;
      ready_timestamps_[i] = Timestamp::Done();
@ -133,6 +133,11 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const InputStream& Input(const CollectionItemId& id) {
|
||||||
|
CHECK(cc_);
|
||||||
|
return cc_->Inputs().Get(id);
|
||||||
|
}
|
||||||
|
|
||||||
PacketType packet_type_;
|
PacketType packet_type_;
|
||||||
std::function<void()> headers_ready_callback_;
|
std::function<void()> headers_ready_callback_;
|
||||||
std::function<void()> notification_callback_;
|
std::function<void()> notification_callback_;
|
||||||
|
@ -262,6 +267,344 @@ TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) {
|
||||||
EXPECT_TRUE(errors_.empty());
|
EXPECT_TRUE(errors_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImmediateInputStreamHandlerTest, ProcessTimestampBounds) {
|
||||||
|
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||||
|
|
||||||
|
Timestamp min_stream_timestamp;
|
||||||
|
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::PreStream());
|
||||||
|
|
||||||
|
const auto& input_a_id = name_to_id_["input_a"];
|
||||||
|
const auto& input_b_id = name_to_id_["input_b"];
|
||||||
|
const auto& input_c_id = name_to_id_["input_c"];
|
||||||
|
|
||||||
|
std::list<Packet> packets;
|
||||||
|
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
|
||||||
|
input_stream_handler_->AddPackets(input_b_id, packets);
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
|
||||||
|
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
|
||||||
|
// FinalizeInputSet() is a no-op.
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
|
||||||
|
// FinalizeInputSet() is a no-op.
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
|
||||||
|
EXPECT_TRUE(
|
||||||
|
input_stream_handler_->GetInputStreamManager(input_b_id)->IsEmpty());
|
||||||
|
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
EXPECT_TRUE(errors_.empty());
|
||||||
|
|
||||||
|
// Schedule invocation for Close.
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
EXPECT_TRUE(errors_.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ImmediateInputStreamHandlerTest,
|
||||||
|
ProcessTimestampBoundsNoOpScheduleInvocations) {
|
||||||
|
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||||
|
|
||||||
|
const auto& input_a_id = name_to_id_["input_a"];
|
||||||
|
const auto& input_b_id = name_to_id_["input_b"];
|
||||||
|
const auto& input_c_id = name_to_id_["input_c"];
|
||||||
|
|
||||||
|
Timestamp min_stream_timestamp;
|
||||||
|
std::list<Packet> packets;
|
||||||
|
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
|
||||||
|
input_stream_handler_->AddPackets(input_b_id, packets);
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
|
||||||
|
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||||
|
|
||||||
|
// FinalizeInputSet() is a no-op.
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
EXPECT_TRUE(errors_.empty());
|
||||||
|
|
||||||
|
// Try to schedule invocations several times again. Considering nothing
|
||||||
|
// changed since last invocation nothing should be scheduled.
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp(2));
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
}
|
||||||
|
|
||||||
|
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||||
|
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
EXPECT_TRUE(errors_.empty());
|
||||||
|
|
||||||
|
// Schedule invocation for Close.
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
|
||||||
|
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||||
|
&cc_->Inputs());
|
||||||
|
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||||
|
EXPECT_TRUE(errors_.empty());
|
||||||
|
|
||||||
|
// Try to schedule invocations several times again. Considering nothing
|
||||||
|
// changed since last invocation nothing should be scheduled.
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||||
|
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Due to some temporary changes in ImmediateInputStreamHandler some packets
|
||||||
|
// - were queued but never released
|
||||||
|
// - were released in incorrect order
|
||||||
|
// As other test cases were passing, this test case is designed to ensure that.
|
||||||
|
TEST_F(ImmediateInputStreamHandlerTest, VerifyPacketsReleaseOrder) {
|
||||||
|
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||||
|
|
||||||
|
const auto& input_a_id = name_to_id_["input_a"];
|
||||||
|
const auto& input_b_id = name_to_id_["input_b"];
|
||||||
|
const auto& input_c_id = name_to_id_["input_c"];
|
||||||
|
|
||||||
|
Packet packet_a = Adopt(new std::string("packet a"));
|
||||||
|
Packet packet_b = Adopt(new std::string("packet b"));
|
||||||
|
Packet packet_c = Adopt(new std::string("packet c"));
|
||||||
|
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(1))});
|
||||||
|
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(2))});
|
||||||
|
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(3))});
|
||||||
|
|
||||||
|
Timestamp min_stream_timestamp;
|
||||||
|
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||||
|
/*max_allowance=*/1, &min_stream_timestamp));
|
||||||
|
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
  ASSERT_FALSE(Input(input_a_id).IsEmpty());
  EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(1));
  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));

  // FinalizeInputSet() is a no-op.
  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(5))});
  input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(5))});
  input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(5))});

  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp(2));
  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
  ASSERT_FALSE(Input(input_b_id).IsEmpty());
  EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(2));
  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));

  // FinalizeInputSet() is a no-op.
  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp(3));
  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(4));
  ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(3));

  // FinalizeInputSet() is a no-op.
  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp(5));
  ASSERT_FALSE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(5));
  ASSERT_FALSE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(5));
  ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(5));

  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
  input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
  input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());

  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());

  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  // Schedule invocation for Close.
  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());

  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
                                          &cc_->Inputs());
  input_stream_handler_->ClearCurrentInputs(cc_);

  ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
      /*max_allowance=*/1, &min_stream_timestamp));
  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
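
For context, the handler exercised by these tests is selected per node in a graph config. A minimal sketch (the calculator and stream names below are placeholders, not part of this change):

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Hypothetical: build a node config that uses ImmediateInputStreamHandler.
CalculatorGraphConfig::Node MakeImmediateIshNode() {
  CalculatorGraphConfig::Node node;
  node.set_calculator("PassThroughCalculator");  // placeholder calculator
  node.add_input_stream("in");
  node.add_output_stream("out");
  node.mutable_input_stream_handler()->set_input_stream_handler(
      "ImmediateInputStreamHandler");
  return node;
}

}  // namespace mediapipe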

// This test simulates how CalculatorNode::ProcessNode() uses an input
// stream handler and the associated input streams.
TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) {
@@ -641,4 +641,61 @@ class DummyTestCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(DummyTestCalculator);

// A Calculator that passes the input value to the output after sleeping for
// a set number of microseconds.
class PassThroughWithSleepCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
    cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
    if (sleep_micros_ < 0) {
      return ::mediapipe::InternalError("SLEEP_MICROS should be >= 0");
    }
    clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Process(CalculatorContext* cc) final {
    clock_->Sleep(absl::Microseconds(sleep_micros_));
    int value = cc->Inputs().Index(0).Value().Get<int>();
    cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

 private:
  int sleep_micros_ = 0;
  std::shared_ptr<Clock> clock_;
};
REGISTER_CALCULATOR(PassThroughWithSleepCalculator);

// A Calculator that multiplies two input values.
class MultiplyIntCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
    // cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    RET_CHECK(cc->Outputs().HasTag("OUT"));
    cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Process(CalculatorContext* cc) final {
    int x = cc->Inputs().Index(0).Value().Get<int>();
    int y = cc->Inputs().Index(1).Value().Get<int>();
    cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(MultiplyIntCalculator);

}  // namespace mediapipe
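
A minimal sketch of how MultiplyIntCalculator above could be wired into a test graph. The stream names and the use of ParseTextProtoOrDie are illustrative assumptions, not part of this change:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Hypothetical test graph config: multiplies two int streams.
CalculatorGraphConfig MakeMultiplyGraphConfig() {
  return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "in_a"
    input_stream: "in_b"
    output_stream: "product"
    node {
      calculator: "MultiplyIntCalculator"
      input_stream: "in_a"
      input_stream: "in_b"
      output_stream: "OUT:product"
    }
  )pb");
}

}  // namespace mediapipe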
@@ -101,6 +101,13 @@ std::string ParseNameFromStream(const std::string& stream) {
  return name;
}

std::pair<std::string, int> ParseTagIndex(const std::string& tag_index) {
  std::string tag;
  int index;
  MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
  return {tag, index};
}

std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) {
  std::string tag, name;
  int index;

@@ -76,6 +76,9 @@ std::string CanonicalNodeName(const CalculatorGraphConfig& graph_config,
// Parses the name from a "tag:index:name".
std::string ParseNameFromStream(const std::string& stream);

// Parses the TagIndex from a "tag:index".
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index);

// Parses the TagIndex from a "tag:index:name".
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream);
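
A brief usage sketch for the new helpers. The example strings are illustrative, and the calls assume the declarations above are in scope:

// Hypothetical example values.
std::pair<std::string, int> tag_index = ParseTagIndex("IMAGE:2");
// tag_index == {"IMAGE", 2}
std::pair<std::string, int> stream_tag_index =
    ParseTagIndexFromStream("IMAGE:2:input_video");
// stream_tag_index == {"IMAGE", 2}; ParseNameFromStream("IMAGE:2:input_video")
// returns "input_video".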
mediapipe/framework/tool/testdata/BUILD
@@ -13,15 +13,15 @@
# limitations under the License.
#

-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//mediapipe:__subpackages__"])
-
load(
    "//mediapipe/framework/tool:mediapipe_graph.bzl",
    "mediapipe_simple_subgraph",
)

+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//mediapipe:__subpackages__"])
+
filegroup(
    name = "test_graph",
    srcs = ["test.pbtxt"],
@@ -31,6 +31,8 @@ exports_files([
    "test.pbtxt",
    "dub_quad_test_subgraph.pbtxt",
    "nested_test_subgraph.pbtxt",
    "single_flow_container_test.pbtxt",
    "dual_flow_container_test.pbtxt",
])

mediapipe_simple_subgraph(
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
load("//mediapipe/gpu:metal.bzl", "metal_library")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
# Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can
# interfere with desktop GL. b/73494271
config_setting(
@@ -39,6 +39,7 @@ namespace mediapipe {
//   ROTATION: the counterclockwise rotation angle in degrees. This allows
//   user to specify different rotation angles for different frames. If this
//   stream is provided, it will override the ROTATION input side packet.
//   OUTPUT_DIMENSIONS: the output width and height in pixels.
// Additional output streams:
//   TOP_BOTTOM_PADDING: If using FIT scale mode, this stream outputs the padding
//   size of the input image in normalized value [0, 1] for top and bottom
@@ -103,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator);
  if (cc->Inputs().HasTag("ROTATION")) {
    cc->Inputs().Tag("ROTATION").Set<int>();
  }
  if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
    cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
  }
  MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));

  if (cc->InputSidePackets().HasTag("OPTIONS")) {
@@ -181,6 +185,18 @@ REGISTER_CALCULATOR(GlScalerCalculator);
}

::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) {
  if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
    if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
      // OUTPUT_DIMENSIONS input stream is specified, but value is missing.
      return ::mediapipe::OkStatus();
    }

    const auto& dimensions =
        cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<DimensionsPacketType>();
    dst_width_ = dimensions[0];
    dst_height_ = dimensions[1];
  }

  return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
    const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>();
    QuadRenderer* renderer = nullptr;
@@ -140,6 +140,9 @@ node {
      num_landmarks: 21
      input_image_width: 256
      input_image_height: 256
      # The additional scaling factor is used to account for the Z coordinate
      # distribution in the training data.
      normalize_z: 0.4
    }
  }
}

@@ -144,6 +144,9 @@ node {
      num_landmarks: 21
      input_image_width: 256
      input_image_height: 256
      # The additional scaling factor is used to account for the Z coordinate
      # distribution in the training data.
      normalize_z: 0.4
    }
  }
}
@@ -25,6 +25,7 @@ android_library(
    ),
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/glutil",
        "//third_party:androidx_appcompat",
@@ -14,17 +14,21 @@
package com.google.mediapipe.components;

+import static java.lang.Math.max;
+
import android.graphics.SurfaceTexture;
import android.opengl.GLES11Ext;
import android.opengl.GLES20;
import android.util.Log;
import com.google.mediapipe.framework.AppTextureFrame;
+import com.google.mediapipe.framework.GlSyncToken;
import com.google.mediapipe.glutil.ExternalTextureRenderer;
import com.google.mediapipe.glutil.GlThread;
import com.google.mediapipe.glutil.ShaderUtil;
+import java.util.ArrayDeque;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
+import java.util.Queue;
import javax.microedition.khronos.egl.EGLContext;

/**
@@ -204,8 +208,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
    private static final long NANOS_PER_MICRO = 1000; // Nanoseconds in one microsecond.
    private volatile SurfaceTexture surfaceTexture = null;
    private final List<TextureFrameConsumer> consumers;
-    private List<AppTextureFrame> outputFrames = null;
-    private int outputFrameIndex = -1;
+    private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>();
+    private int framesInUse = 0;
+    private final int framesToKeep;
+
    private ExternalTextureRenderer renderer = null;
    private long nextFrameTimestampOffset = 0;
    private long timestampOffsetNanos = 0;
@@ -215,10 +222,27 @@ public class ExternalTextureConverter implements TextureFrameProducer {
    protected int destinationWidth = 0;
    protected int destinationHeight = 0;

+    private class PoolTextureFrame extends AppTextureFrame {
+      public PoolTextureFrame(int textureName, int width, int height) {
+        super(textureName, width, height);
+      }
+
+      @Override
+      public void release(GlSyncToken syncToken) {
+        super.release(syncToken);
+        poolFrameReleased(this);
+      }
+
+      @Override
+      public void release() {
+        super.release();
+        poolFrameReleased(this);
+      }
+    }
+
    public RenderThread(EGLContext parentContext, int numBuffers) {
      super(parentContext);
-      outputFrames = new ArrayList<>();
-      outputFrames.addAll(Collections.nCopies(numBuffers, null));
+      framesToKeep = numBuffers;
      renderer = new ExternalTextureRenderer();
      consumers = new ArrayList<>();
    }
@@ -283,8 +307,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
    @Override
    public void releaseGl() {
      setSurfaceTexture(null, 0, 0);
-      for (int i = 0; i < outputFrames.size(); ++i) {
-        teardownDestination(i);
+      while (!framesAvailable.isEmpty()) {
+        teardownFrame(framesAvailable.remove());
      }
      renderer.release();
      super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
@@ -337,16 +361,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
      }
    }

-    private void teardownDestination(int index) {
-      if (outputFrames.get(index) != null) {
-        waitUntilReleased(outputFrames.get(index));
-        GLES20.glDeleteTextures(1, new int[] {outputFrames.get(index).getTextureName()}, 0);
-        outputFrames.set(index, null);
-      }
-    }
+    private static void teardownFrame(AppTextureFrame frame) {
+      GLES20.glDeleteTextures(1, new int[] {frame.getTextureName()}, 0);
+    }

-    private void setupDestination(int index) {
-      teardownDestination(index);
+    private PoolTextureFrame createFrame() {
      int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight);
      Log.d(
          TAG,
@@ -354,11 +373,9 @@ public class ExternalTextureConverter implements TextureFrameProducer {
              "Created output texture: %d width: %d height: %d",
              destinationTextureId, destinationWidth, destinationHeight));
      bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight);
-      outputFrames.set(
-          index, new AppTextureFrame(destinationTextureId, destinationWidth, destinationHeight));
+      return new PoolTextureFrame(destinationTextureId, destinationWidth, destinationHeight);
    }

    /**
     * Gets next available frame or creates new one if next frame is not initialized
     * or cannot be used with current surface texture.
@@ -371,20 +388,38 @@ public class ExternalTextureConverter implements TextureFrameProducer {
     * NOTE: must be invoked on GL thread
     */
    private AppTextureFrame nextOutputFrame() {
-      outputFrameIndex = (outputFrameIndex + 1) % outputFrames.size();
-      AppTextureFrame outputFrame = outputFrames.get(outputFrameIndex);
-      // Check if the size has changed.
-      if (outputFrame == null
-          || outputFrame.getWidth() != destinationWidth
-          || outputFrame.getHeight() != destinationHeight) {
-        // setupDestination will wait for the frame to be released before reallocating it.
-        setupDestination(outputFrameIndex);
-        outputFrame = outputFrames.get(outputFrameIndex);
-      }
-      waitUntilReleased(outputFrame);
+      PoolTextureFrame outputFrame;
+      synchronized (this) {
+        outputFrame = framesAvailable.poll();
+        framesInUse++;
+      }
+      if (outputFrame == null) {
+        outputFrame = createFrame();
+      } else if (outputFrame.getWidth() != destinationWidth
+          || outputFrame.getHeight() != destinationHeight) {
+        // Create anew if size has changed.
+        // TODO: waiting for the consumer sync here may not be necessary.
+        waitUntilReleased(outputFrame);
+        teardownFrame(outputFrame);
+        outputFrame = createFrame();
+      } else {
+        // Note: waitUntilReleased does two things: waits for the frame to be released by the CPU,
+        // and syncs with the GPU sync token provided by the consumer. The first part is redundant
+        // here (and completes immediately), but the second part is still needed.
+        waitUntilReleased(outputFrame);
+      }
      return outputFrame;
    }

+    protected synchronized void poolFrameReleased(PoolTextureFrame frame) {
+      framesAvailable.offer(frame);
+      framesInUse--;
+      int keep = max(framesToKeep - framesInUse, 0);
+      while (framesAvailable.size() > keep) {
+        teardownFrame(framesAvailable.remove());
+      }
+    }
+
    /**
     * Updates output frame with current pixels of surface texture and corresponding timestamp.
     *
@@ -417,16 +452,22 @@ public class ExternalTextureConverter implements TextureFrameProducer {
        Log.v(
            TAG,
            String.format(
-                "Waiting for tex: %d width: %d height: %d",
-                frame.getTextureName(), frame.getWidth(), frame.getHeight()));
+                "Waiting for tex: %d width: %d height: %d timestamp: %d",
+                frame.getTextureName(),
+                frame.getWidth(),
+                frame.getHeight(),
+                frame.getTimestamp()));
      }
      frame.waitUntilReleased();
      if (Log.isLoggable(TAG, Log.VERBOSE)) {
        Log.v(
            TAG,
            String.format(
-                "Finished waiting for tex: %d width: %d height: %d",
-                frame.getTextureName(), frame.getWidth(), frame.getHeight()));
+                "Finished waiting for tex: %d width: %d height: %d timestamp: %d",
+                frame.getTextureName(),
+                frame.getWidth(),
+                frame.getHeight(),
+                frame.getTimestamp()));
      }
    } catch (InterruptedException ie) {
      // Someone interrupted our thread. This is not supposed to happen: we own
@@ -20,6 +20,7 @@ import android.media.AudioFormat;
import android.os.Handler;
import android.util.Log;
import com.google.common.base.Preconditions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
@@ -32,10 +33,12 @@ import com.google.mediapipe.framework.SurfaceOutput;
import com.google.mediapipe.framework.TextureFrame;
import java.io.File;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;

@@ -106,6 +109,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
    initializeGraphAndPacketCreator(context, graphName);
  }

  /**
   * Constructor.
   *
   * @param graphConfig the proto object representation of the graph.
   */
  public FrameProcessor(CalculatorGraphConfig graphConfig) {
    initializeGraphAndPacketCreator(graphConfig);
  }

  /**
   * Initializes a graph for processing data in real time.
   *
@@ -123,6 +135,17 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
    packetCreator = new AndroidPacketCreator(mediapipeGraph);
  }

  /**
   * Initializes a graph for processing data in real time.
   *
   * @param graphConfig the proto object representation of the graph.
   */
  private void initializeGraphAndPacketCreator(CalculatorGraphConfig graphConfig) {
    mediapipeGraph = new Graph();
    mediapipeGraph.loadBinaryGraph(graphConfig);
    packetCreator = new AndroidPacketCreator(mediapipeGraph);
  }

  /** Callback for errors occurring during processing in the graph. */
  public interface ErrorListener {
    void onError(RuntimeException error);
@@ -186,6 +209,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
      currentConsumers = videoConsumers;
    }
    for (TextureFrameConsumer consumer : currentConsumers) {
      // Note: each consumer will release its TextureFrame, so each gets a separate object
      // (though they all reference the same data).
      TextureFrame frame = PacketGetter.getTextureFrame(packet);
      if (Log.isLoggable(TAG, Log.VERBOSE)) {
        Log.v(
@@ -373,9 +398,10 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor

  /**
   * Returns true if the MediaPipe graph can accept one more input frame.
   *
   * @throws MediaPipeException for any error status.
   */
-  private boolean maybeAcceptNewFrame() {
+  private boolean maybeAcceptNewFrame(long timestamp) {
    if (!started.getAndSet(true)) {
      startGraph();
    }
@@ -395,7 +421,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
              frame.getTextureName(), frame.getWidth(), frame.getHeight()));
    }

-    if (!maybeAcceptNewFrame()) {
+    if (!maybeAcceptNewFrame(frame.getTimestamp())) {
      return;
    }

@@ -451,7 +477,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
  public void onNewFrame(final Bitmap bitmap, long timestamp) {
    Packet packet = null;
    try {
-      if (!maybeAcceptNewFrame()) {
+      if (!maybeAcceptNewFrame(timestamp)) {
        return;
      }
@@ -17,8 +17,8 @@ package com.google.mediapipe.components;
import android.Manifest;
import android.app.Activity;
import android.content.pm.PackageManager;
-import androidx.core.app.ActivityCompat;
import android.util.Log;
+import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

/** Manages camera permission request and handling. */
@@ -18,6 +18,10 @@ import com.google.mediapipe.framework.TextureFrame;

/** Lightweight abstraction for an object that can receive video frames. */
public interface TextureFrameConsumer {
-  /** Called when a new {@link TextureFrame} is available. */
+  /**
+   * Called when a new {@link TextureFrame} is available.
+   *
+   * Important: implementations of this method should call frame.release().
+   **/
  public abstract void onNewFrame(TextureFrame frame);
}
@@ -272,6 +272,10 @@ public final class PacketGetter {
   * <p>Note: in order for the application to be able to use the texture, its GL context must be
   * linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)}
   * with the native handle to the application's GL context as the second argument.
   *
   * <p>The returned GraphTextureFrame must be released by the caller. If this method is called
   * multiple times, each returned GraphTextureFrame is an independent reference to the underlying
   * texture data, and must be released individually.
   */
  public static GraphTextureFrame getTextureFrame(final Packet packet) {
    return new GraphTextureFrame(
@@ -1,7 +0,0 @@
-tricorder: {
-  options: {
-    builder: {
-      config: "android_arm"
-    }
-  }
-}
@@ -184,6 +184,20 @@ typedef NS_ENUM(int, MPPPacketType) {
             packetType:(MPPPacketType)packetType
              timestamp:(const mediapipe::Timestamp &)timestamp;

/// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. Drops frames and
/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES,
/// allows MediaPipe to overwrite the packet contents on successful sending for
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
/// Sets error to a non-nil value if an error occurs in the graph when sending the
/// packet.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
             intoStream:(const std::string &)inputName
             packetType:(MPPPacketType)packetType
              timestamp:(const mediapipe::Timestamp &)timestamp
         allowOverwrite:(BOOL)allowOverwrite
                  error:(NSError **)error;

/// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. The timestamp is
/// automatically incremented from the last timestamp used by this method. Drops
@@ -327,22 +327,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
             packetType:(MPPPacketType)packetType
              timestamp:(const mediapipe::Timestamp&)timestamp
         allowOverwrite:(BOOL)allowOverwrite {
+  NSError* error;
+  bool success = [self sendPixelBuffer:imageBuffer
+                            intoStream:inputName
+                            packetType:packetType
+                             timestamp:timestamp
+                        allowOverwrite:allowOverwrite
+                                 error:&error];
+  if (error) {
+    _GTMDevLog(@"failed to send packet: %@", error);
+  }
+  return success;
+}
+
+- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
+             intoStream:(const std::string&)inputName
+             packetType:(MPPPacketType)packetType
+              timestamp:(const mediapipe::Timestamp&)timestamp
+         allowOverwrite:(BOOL)allowOverwrite
+                  error:(NSError**)error {
  if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
  mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType];
-  NSError* error;
  BOOL success;
  if (allowOverwrite) {
    packet = std::move(packet).At(timestamp);
-    success = [self movePacket:std::move(packet)
-                    intoStream:inputName
-                         error:&error];
+    success = [self movePacket:std::move(packet) intoStream:inputName error:error];
  } else {
-    success = [self sendPacket:packet.At(timestamp)
-                    intoStream:inputName
-                         error:&error];
+    success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
  }
  if (success) _framesInFlight++;
-  else _GTMDevLog(@"failed to send packet: %@", error);
  return success;
}
@@ -423,6 +423,10 @@ tasks and tracking (or class) fields for tracking information.
|`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.|
|`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.|
|`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.|
|`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.|
|`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.|
|`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.|
|`region/3d_point/\*`| *special* |`add_bbox_3d_point` / `AddBBox3dPoint`|Operates on 3d_point/{x,y,z} with a single call.|
|`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.|
|`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.|
|`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. Distinguishes empty from unannotated frames.|
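
A minimal C++ sketch of the new 3D point accessors listed above. The values and the two-keypoint example are illustrative, and the accessors are assumed to live in mediapipe::mediasequence as declared in mediapipe/util/sequence/media_sequence.h:

#include <tuple>
#include <vector>

#include "mediapipe/util/sequence/media_sequence.h"

void AddOneFrameOfKeypoints() {
  tensorflow::SequenceExample sequence;
  // Two (x, y, z) keypoints for a single frame (illustrative values).
  std::vector<std::tuple<float, float, float>> keypoints = {
      std::make_tuple(0.1f, 0.2f, 0.3f), std::make_tuple(0.4f, 0.5f, 0.6f)};
  mediapipe::mediasequence::AddBBox3dPoint(keypoints, &sequence);
  mediapipe::mediasequence::AddBBoxTimestamp(2000000, &sequence);
  // Read the first frame's keypoints back.
  std::vector<std::tuple<float, float, float>> read_back =
      mediapipe::mediasequence::GetBBox3dPointAt(sequence, 0);
}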
@@ -229,6 +229,18 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
          sequence);
    }
  }
  if (Get3dPointSize(prefix, *sequence) > 0) {
    std::string x_key = merge_prefix(prefix, kRegion3dPointXKey);
    auto* region_feature_list = MutableFeatureList(x_key, sequence);
    RET_CHECK_EQ(num_bboxes, region_feature_list->feature_size())
        << "Expected number of BBox timestamps and boxes to match.";
    ClearBBoxNumRegions(prefix, sequence);
    for (int i = 0; i < num_bboxes; ++i) {
      AddBBoxNumRegions(
          prefix, region_feature_list->feature(i).float_list().value_size(),
          sequence);
    }
  }
  // Collect which timestamps currently match to which indices in timestamps.
  // skip empty timestamps.
  // Requires sorted indices.
@@ -453,6 +465,47 @@ void ClearPoint(const std::string& prefix,
  ClearBBoxPointX(prefix, sequence);
}

int Get3dPointSize(const std::string& prefix,
                   const tensorflow::SequenceExample& sequence) {
  return GetBBox3dPointXSize(prefix, sequence);
}

std::vector<::std::tuple<float, float, float>> Get3dPointAt(
    const std::string& prefix, const tensorflow::SequenceExample& sequence,
    int index) {
  const auto& xs = GetBBox3dPointXAt(prefix, sequence, index);
  const auto& ys = GetBBox3dPointYAt(prefix, sequence, index);
  const auto& zs = GetBBox3dPointZAt(prefix, sequence, index);
  std::vector<::std::tuple<float, float, float>> points(ys.size());
  for (int i = 0; i < xs.size(); ++i) {
    points[i] = std::make_tuple(xs[i], ys[i], zs[i]);
  }
  return points;
}

void Add3dPoint(const std::string& prefix,
                const std::vector<::std::tuple<float, float, float>>& points,
                tensorflow::SequenceExample* sequence) {
  ::std::vector<float> xs;
  ::std::vector<float> ys;
  ::std::vector<float> zs;
  for (auto& point : points) {
    xs.push_back(std::get<0>(point));
    ys.push_back(std::get<1>(point));
    zs.push_back(std::get<2>(point));
  }
  AddBBox3dPointX(prefix, xs, sequence);
  AddBBox3dPointY(prefix, ys, sequence);
  AddBBox3dPointZ(prefix, zs, sequence);
}

void Clear3dPoint(const std::string& prefix,
                  tensorflow::SequenceExample* sequence) {
  ClearBBox3dPointX(prefix, sequence);
  ClearBBox3dPointY(prefix, sequence);
  ClearBBox3dPointZ(prefix, sequence);
}

std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt(
    const std::string& prefix, const tensorflow::SequenceExample& sequence,
    int index) {
@@ -268,6 +268,10 @@ const char kRegionBBoxXMaxKey[] = "region/bbox/xmax";
const char kRegionPointXKey[] = "region/point/x";
const char kRegionPointYKey[] = "region/point/y";
const char kRegionRadiusKey[] = "region/radius";
// The 3d point can denote keypoints.
const char kRegion3dPointXKey[] = "region/3d_point/x";
const char kRegion3dPointYKey[] = "region/3d_point/y";
const char kRegion3dPointZKey[] = "region/3d_point/z";
// The number of regions at that timestep.
const char kRegionNumRegionsKey[] = "region/num_regions";
// Whether that timestep is annotated for bounding regions.
@@ -333,6 +337,18 @@ void AddPoint(const std::string& prefix,
void ClearPoint(const std::string& prefix,
                tensorflow::SequenceExample* sequence);

// The input and output format is a tuple of <x, y, z> coordinates.
int Get3dPointSize(const std::string& prefix,
                   const tensorflow::SequenceExample& sequence);
std::vector<std::tuple<float, float, float>> Get3dPointAt(
    const std::string& prefix, const tensorflow::SequenceExample& sequence,
    int index);
void Add3dPoint(const std::string& prefix,
                const std::vector<std::tuple<float, float, float>>& points,
                tensorflow::SequenceExample* sequence);
void Clear3dPoint(const std::string& prefix,
                  tensorflow::SequenceExample* sequence);
#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \
  inline int CONCAT_STR3(Get, identifier, \
                         Size)(const tensorflow::SequenceExample& sequence) { \
@@ -388,6 +404,44 @@ void ClearPoint(const std::string& prefix,
  inline void CONCAT_STR3(Clear, identifier, Point)( \
      std::string name, tensorflow::SequenceExample * sequence) { \
    return ClearPoint(name, sequence); \
  } \
  inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
      const tensorflow::SequenceExample& sequence) { \
    return Get3dPointSize(prefix, sequence); \
  } \
  inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
      const std::string& name, const tensorflow::SequenceExample& sequence) { \
    return Get3dPointSize(name, sequence); \
  } \
  inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
      Get, identifier, 3dPointAt)(const tensorflow::SequenceExample& sequence, \
                                  int index) { \
    return Get3dPointAt(prefix, sequence, index); \
  } \
  inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
      Get, identifier, 3dPointAt)(const std::string& name, \
                                  const tensorflow::SequenceExample& sequence, \
                                  int index) { \
    return Get3dPointAt(name, sequence, index); \
  } \
  inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
      const std::vector<std::tuple<float, float, float>>& points, \
      tensorflow::SequenceExample* sequence) { \
    return Add3dPoint(prefix, points, sequence); \
  } \
  inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
      const std::string& name, \
      const std::vector<std::tuple<float, float, float>>& points, \
      tensorflow::SequenceExample* sequence) { \
    return Add3dPoint(name, points, sequence); \
  } \
  inline void CONCAT_STR3(Clear, identifier, \
                          3dPoint)(tensorflow::SequenceExample * sequence) { \
    return Clear3dPoint(prefix, sequence); \
  } \
  inline void CONCAT_STR3(Clear, identifier, 3dPoint)( \
      std::string name, tensorflow::SequenceExample * sequence) { \
    return Clear3dPoint(name, sequence); \
  }

#define PREFIXED_BBOX(identifier, prefix) \
@@ -435,6 +489,12 @@ void ClearPoint(const std::string& prefix,
                                          kRegionPointYKey, prefix) \
  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, Radius), \
                                         kRegionRadiusKey, prefix) \
  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointX), \
                                         kRegion3dPointXKey, prefix) \
  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointY), \
                                         kRegion3dPointYKey, prefix) \
  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointZ), \
                                         kRegion3dPointZKey, prefix) \
  FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST( \
      CONCAT_STR2(identifier, EmbeddingFloats), kRegionEmbeddingFloatKey, \
      prefix) \
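
The prefix-parameterized forms above write the same keys under a prefix. A short sketch; the "PREDICTED" prefix mirrors PREDICTED_PREFIX in the Python API below and is used here only as an illustration:

// Sketch, assuming mediapipe::mediasequence from media_sequence.h.
tensorflow::SequenceExample sequence;
std::vector<std::tuple<float, float, float>> predicted = {
    std::make_tuple(0.15f, 0.25f, 0.35f)};
mediapipe::mediasequence::Add3dPoint("PREDICTED", predicted, &sequence);
int num_frames = mediapipe::mediasequence::Get3dPointSize("PREDICTED", sequence);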
@@ -262,6 +262,10 @@ REGION_BBOX_XMAX_KEY = "region/bbox/xmax"
REGION_POINT_X_KEY = "region/point/x"
REGION_POINT_Y_KEY = "region/point/y"
REGION_RADIUS_KEY = "region/radius"
# The 3D point can denote keypoints.
REGION_3D_POINT_X_KEY = "region/3d_point/x"
REGION_3D_POINT_Y_KEY = "region/3d_point/y"
REGION_3D_POINT_Z_KEY = "region/3d_point/z"
# The number of regions at that timestep.
REGION_NUM_REGIONS_KEY = "region/num_regions"
# Whether that timestep is annotated for regions.
@@ -365,6 +369,15 @@ def _create_region_with_prefix(name, prefix):
                                     prefix=prefix, module_dict=globals())
  msu.create_float_list_feature_list(name + "_point_y", REGION_POINT_Y_KEY,
                                     prefix=prefix, module_dict=globals())
  msu.create_float_list_feature_list(
      name + "_3d_point_x", REGION_3D_POINT_X_KEY,
      prefix=prefix, module_dict=globals())
  msu.create_float_list_feature_list(
      name + "_3d_point_y", REGION_3D_POINT_Y_KEY,
      prefix=prefix, module_dict=globals())
  msu.create_float_list_feature_list(
      name + "_3d_point_z", REGION_3D_POINT_Z_KEY,
      prefix=prefix, module_dict=globals())
  msu.create_bytes_list_context_feature(name + "_parts",
                                        REGION_PARTS_KEY,
                                        prefix=prefix, module_dict=globals())
@@ -406,6 +419,39 @@ def _create_region_with_prefix(name, prefix):
    clear_bbox_xmin(sequence_example, prefix=prefix)
    clear_bbox_ymax(sequence_example, prefix=prefix)
    clear_bbox_xmax(sequence_example, prefix=prefix)
  def get_prefixed_point_at(index, sequence_example, prefix):
    return np.stack((
        get_bbox_point_y_at(index, sequence_example, prefix=prefix),
        get_bbox_point_x_at(index, sequence_example, prefix=prefix)),
        1)
  def add_prefixed_point(values, sequence_example, prefix):
    add_bbox_point_y(values[:, 0], sequence_example, prefix=prefix)
    add_bbox_point_x(values[:, 1], sequence_example, prefix=prefix)
  def get_prefixed_point_size(sequence_example, prefix):
    return get_bbox_point_y_size(sequence_example, prefix=prefix)
  def has_prefixed_point(sequence_example, prefix):
    return has_bbox_point_y(sequence_example, prefix=prefix)
  def clear_prefixed_point(sequence_example, prefix):
    clear_bbox_point_y(sequence_example, prefix=prefix)
    clear_bbox_point_x(sequence_example, prefix=prefix)
  def get_prefixed_3d_point_at(index, sequence_example, prefix):
    return np.stack((
        get_bbox_3d_point_x_at(index, sequence_example, prefix=prefix),
        get_bbox_3d_point_y_at(index, sequence_example, prefix=prefix),
        get_bbox_3d_point_z_at(index, sequence_example, prefix=prefix)),
        1)
  def add_prefixed_3d_point(values, sequence_example, prefix):
    add_bbox_3d_point_x(values[:, 0], sequence_example, prefix=prefix)
    add_bbox_3d_point_y(values[:, 1], sequence_example, prefix=prefix)
    add_bbox_3d_point_z(values[:, 2], sequence_example, prefix=prefix)
  def get_prefixed_3d_point_size(sequence_example, prefix):
    return get_bbox_3d_point_x_size(sequence_example, prefix=prefix)
  def has_prefixed_3d_point(sequence_example, prefix):
    return has_bbox_3d_point_x(sequence_example, prefix=prefix)
  def clear_prefixed_3d_point(sequence_example, prefix):
    clear_bbox_3d_point_x(sequence_example, prefix=prefix)
    clear_bbox_3d_point_y(sequence_example, prefix=prefix)
    clear_bbox_3d_point_z(sequence_example, prefix=prefix)
  # pylint: enable=undefined-variable
  msu.add_functions_to_module({
      "get_" + name + "_at":
@@ -419,6 +465,30 @@ def _create_region_with_prefix(name, prefix):
      "clear_" + name:
          functools.partial(clear_prefixed_bbox, prefix=prefix),
  }, module_dict=globals())
  msu.add_functions_to_module({
      "get_" + name + "_point_at":
          functools.partial(get_prefixed_point_at, prefix=prefix),
      "add_" + name + "_point":
          functools.partial(add_prefixed_point, prefix=prefix),
      "get_" + name + "_point_size":
          functools.partial(get_prefixed_point_size, prefix=prefix),
      "has_" + name + "_point":
          functools.partial(has_prefixed_point, prefix=prefix),
      "clear_" + name + "_point":
          functools.partial(clear_prefixed_point, prefix=prefix),
  }, module_dict=globals())
  msu.add_functions_to_module({
      "get_" + name + "_3d_point_at":
          functools.partial(get_prefixed_3d_point_at, prefix=prefix),
      "add_" + name + "_3d_point":
          functools.partial(add_prefixed_3d_point, prefix=prefix),
      "get_" + name + "_3d_point_size":
          functools.partial(get_prefixed_3d_point_size, prefix=prefix),
      "has_" + name + "_3d_point":
          functools.partial(has_prefixed_3d_point, prefix=prefix),
      "clear_" + name + "_3d_point":
          functools.partial(clear_prefixed_3d_point, prefix=prefix),
  }, module_dict=globals())


PREDICTED_PREFIX = "PREDICTED"
@@ -436,6 +436,21 @@ TEST(MediaSequenceTest, RoundTripBBoxPointPrefixed) {
  }
}

TEST(MediaSequenceTest, RoundTripBBox3dPoint) {
  tensorflow::SequenceExample sequence;
  std::vector<std::vector<std::tuple<float, float, float>>> points = {
      {std::make_tuple(0.3, 0.5, 0.1), std::make_tuple(0.4, 0.7, 0.2)},
      {std::make_tuple(0.7, 0.5, 0.3), std::make_tuple(0.3, 0.4, 0.4)}};
  for (int i = 0; i < points.size(); ++i) {
    AddBBox3dPoint(points[i], &sequence);
    ASSERT_EQ(GetBBox3dPointSize(sequence), i + 1);
    const auto& sequence_points = GetBBox3dPointAt(sequence, i);
    for (int j = 0; j < sequence_points.size(); ++j) {
      EXPECT_EQ(sequence_points[j], points[i][j]);
    }
  }
}

TEST(MediaSequenceTest, RoundTripRegionParts) {
  tensorflow::SequenceExample sequence;
  std::vector<std::string> parts = {"HEAD", "FEET"};
@@ -89,6 +89,9 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.add_bbox_xmax((0.47, 0.49), example)
    ms.add_bbox_point_x((0.47, 0.49), example)
    ms.add_bbox_point_y((0.47, 0.49), example)
    ms.add_bbox_3d_point_x((0.47, 0.49), example)
    ms.add_bbox_3d_point_y((0.47, 0.49), example)
    ms.add_bbox_3d_point_z((0.47, 0.49), example)
    ms.add_predicted_bbox_ymin((0.47, 0.49), example)
    ms.add_predicted_bbox_xmin((0.47, 0.49), example)
    ms.add_predicted_bbox_ymax((0.47, 0.49), example)
@@ -133,6 +136,30 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.clear_bbox(example)
    self.assertEqual(0, ms.get_bbox_size(example))

  def test_point_round_trip(self):
    example = tf.train.SequenceExample()
    points = np.array([[0.1, 0.2],
                       [0.5, 0.6]])
    ms.add_bbox_point(points, example)
    ms.add_bbox_point(points, example)
    self.assertEqual(2, ms.get_bbox_point_size(example))
    self.assertAllClose(points, ms.get_bbox_point_at(0, example))
    self.assertTrue(ms.has_bbox_point(example))
    ms.clear_bbox_point(example)
    self.assertEqual(0, ms.get_bbox_point_size(example))

  def test_3d_point_round_trip(self):
    example = tf.train.SequenceExample()
    points = np.array([[0.1, 0.2, 0.3],
                       [0.5, 0.6, 0.7]])
    ms.add_bbox_3d_point(points, example)
    ms.add_bbox_3d_point(points, example)
    self.assertEqual(2, ms.get_bbox_3d_point_size(example))
    self.assertAllClose(points, ms.get_bbox_3d_point_at(0, example))
    self.assertTrue(ms.has_bbox_3d_point(example))
    ms.clear_bbox_3d_point(example)
    self.assertEqual(0, ms.get_bbox_3d_point_size(example))

  def test_predicted_bbox_round_trip(self):
    example = tf.train.SequenceExample()
    boxes = np.array([[0.1, 0.2, 0.3, 0.4],
@@ -19,6 +19,14 @@ package(default_visibility = [
     "//mediapipe:__subpackages__",
 ])
 
+cc_library(
+    name = "config",
+    hdrs = ["config.h"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+    ],
+)
+
 cc_library(
     name = "cpu_op_resolver",
     srcs = ["cpu_op_resolver.cc"],
@@ -69,6 +77,7 @@ cc_test(
     srcs = ["tensor_buffer_test.cc"],
     deps = [
         ":tensor_buffer",
+        ":config",
         "//mediapipe/framework/port:gtest_main",
     ] + select({
         "//mediapipe/gpu:disable_gpu": [],
@@ -99,6 +108,7 @@ cc_library(
             "@org_tensorflow//tensorflow/lite:framework",
             "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
+            "@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
            "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
         ],
         "//mediapipe:android": [
@@ -108,7 +118,9 @@ cc_library(
             "//mediapipe/framework/port:statusor",
             "@org_tensorflow//tensorflow/lite:framework",
             "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
+            "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
+            "@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
         ],
     }) + ["@org_tensorflow//tensorflow/lite/core/api"],
mediapipe/util/tflite/config.h (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#define MEDIAPIPE_UTIL_TFLITE_CONFIG_H_

#include "mediapipe/framework/calculator_framework.h"

// MediaPipe code should use the following defines to determine whether TFLite
// GPU support is available, and whether GL or Metal inference is available.

#ifdef MEDIAPIPE_DISABLE_GL_COMPUTE
#define MEDIAPIPE_TFLITE_GL_INFERENCE 0
#else
#define MEDIAPIPE_TFLITE_GL_INFERENCE 1
#endif  // MEDIAPIPE_DISABLE_GL_COMPUTE

#ifdef MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 1
#else
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 0
#endif  // MEDIAPIPE_IOS

#define MEDIAPIPE_TFLITE_GPU_SUPPORTED \
  ((MEDIAPIPE_TFLITE_GL_INFERENCE) || (MEDIAPIPE_TFLITE_METAL_INFERENCE))

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <Metal/Metal.h>
#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE

namespace mediapipe {

#if MEDIAPIPE_TFLITE_GL_INFERENCE
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
typedef id<MTLBuffer> GpuTensor;
#else
struct DummyGpuTensor {};
typedef DummyGpuTensor GpuTensor;  // Dummy define for less #ifdefs
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

}  // namespace mediapipe

#endif  // MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
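As an aside for readers of this change, and not part of the commit itself: a minimal sketch of how downstream MediaPipe code could consume the defines and the GpuTensor alias declared in config.h above. The helper function and its name are hypothetical; only the macros and the typedef come from this file.

#include "mediapipe/util/tflite/config.h"

namespace mediapipe {

// Hypothetical helper, for illustration only: reports which TFLite GPU backend
// a given build selects, based purely on the config.h defines above.
inline const char* TfliteGpuBackendName() {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
  return "OpenGL";  // GpuTensor is ::tflite::gpu::gl::GlBuffer in this build.
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
  return "Metal";   // GpuTensor is id<MTLBuffer> in this build.
#else
  return "none";    // CPU-only build; GpuTensor is the dummy struct.
#endif
}

}  // namespace mediapipe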
@@ -130,8 +130,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto padding = params->padding;
   auto compute_out_size = [padding](int image_size, int filter_size,
                                     int stride) -> int {
-    return padding == kTfLitePaddingSame
-               ? (image_size + stride - 1) / stride
+    return padding == kTfLitePaddingSame ? (image_size + stride - 1) / stride
                : padding == kTfLitePaddingValid
                    ? (image_size - filter_size + stride) / stride
                    : 0;
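A quick worked example of the output-size formula above, with illustrative numbers that are not taken from this change: for image_size = 10, filter_size = 3, and stride = 2, kTfLitePaddingSame gives (10 + 2 - 1) / 2 = 5 output elements, kTfLitePaddingValid gives (10 - 3 + 2) / 2 = 4, and any other padding value falls through to 0. The reformatted lambda computes exactly the same values; only the line breaks changed.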
@@ -2,6 +2,7 @@
 
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/util/tflite/config.h"
 
 namespace mediapipe {
 
@@ -12,7 +13,7 @@ TEST(Cpu, BasicTest) {
   EXPECT_FALSE(tb.UsesGpu());
 }
 
-#if !defined(MEDIAPIPE_DISABLE_GPU)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 TEST(Gpu, BasicTest) {
   TensorBuffer tb;
   std::shared_ptr<tflite::gpu::gl::GlBuffer> tfg_tb =
@@ -20,7 +21,7 @@ TEST(Gpu, BasicTest) {
   tb = TensorBuffer(tfg_tb);
   EXPECT_TRUE(tb.UsesGpu());
 }
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // !MEDIAPIPE_TFLITE_GL_INFERENCE
 
 }  // namespace mediapipe
@@ -30,6 +30,13 @@
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/model.h"
 
+// This code should be enabled as soon as TensorFlow version, which mediapipe
+// uses, will include this module.
+#ifdef __ANDROID__
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#endif
+#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+
 namespace tflite {
 namespace gpu {
 namespace {
@@ -51,6 +58,19 @@ ObjectDef GetSSBOObjectDef(int channels) {
 mediapipe::Status TFLiteGPURunner::InitializeWithModel(
     const tflite::FlatBufferModel& flatbuffer,
     const tflite::OpResolver& op_resolver) {
+  // GraphFloat32 is created twice because, when OpenCL and OpenGL backends are
+  // initialized, different backend-specific graph transformations happen
+  // in-place. As GraphFloat32 is not copyable by design, we keep two copies of
+  // the graph until inference is built. This decision doesn't affect the amount
+  // of run time memory used, because both graph_gl_ and graph_cl_ are deleted
+  // in the end of the initialization stage.
+  graph_gl_ = std::make_unique<GraphFloat32>();
+  graph_cl_ = std::make_unique<GraphFloat32>();
+  MP_RETURN_IF_ERROR(
+      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_gl_.get()));
+  MP_RETURN_IF_ERROR(
+      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_cl_.get()));
+
   for (const auto& input : graph_gl_->inputs()) {
     input_shapes_.push_back(input->tensor.shape);
   }
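For orientation only, and again not part of the diff: a rough sketch of how a caller might drive TFLiteGPURunner using the methods visible in this change (InitializeWithModel(), GetInputShapes(), Build(), Invoke()). The include paths and the free function are assumptions made for illustration; runner construction and tensor binding are omitted.

#include <vector>

#include "mediapipe/framework/port/status_macros.h"   // assumed location of MP_RETURN_IF_ERROR
#include "mediapipe/util/tflite/tflite_gpu_runner.h"  // assumed header for TFLiteGPURunner

// Hypothetical driver, illustration only.
mediapipe::Status RunModelOnce(const tflite::FlatBufferModel& flatbuffer,
                               const tflite::OpResolver& op_resolver,
                               tflite::gpu::TFLiteGPURunner* runner) {
  // Parses the model into the GL and CL graph copies described in the comment
  // above, and records the model's input/output shapes.
  MP_RETURN_IF_ERROR(runner->InitializeWithModel(flatbuffer, op_resolver));
  const std::vector<tflite::gpu::BHWC> input_shapes = runner->GetInputShapes();
  (void)input_shapes;  // e.g. used by the caller to size its input buffers.
  // Builds the backend-specific inference (OpenGL, or OpenCL where available),
  // then runs a single inference step.
  MP_RETURN_IF_ERROR(runner->Build());
  return runner->Invoke();
}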
@@ -140,6 +160,19 @@ mediapipe::Status TFLiteGPURunner::InitializeOpenGL(
 
 absl::Status TFLiteGPURunner::InitializeOpenCL(
     std::unique_ptr<InferenceBuilder>* builder) {
+#ifdef __ANDROID__
+  cl::InferenceEnvironmentOptions env_options;
+  cl::InferenceEnvironmentProperties properties;
+  cl::InferenceOptions cl_options;
+  cl_options.priority1 = options_.priority1;
+  cl_options.priority2 = options_.priority2;
+  cl_options.priority3 = options_.priority3;
+  cl_options.usage = options_.usage;
+  MP_RETURN_IF_ERROR(
+      cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
+  MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
+      cl_options, std::move(*graph_cl_), builder));
+#endif
   return absl::OkStatus();
 }
 
@@ -27,6 +27,10 @@
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/model.h"
 
+#ifdef __ANDROID__
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#endif
+
 namespace tflite {
 namespace gpu {
 
@@ -64,6 +68,9 @@ class TFLiteGPURunner {
   mediapipe::Status Build();
   mediapipe::Status Invoke();
 
+  std::vector<BHWC> GetInputShapes() { return input_shapes_; }
+  std::vector<BHWC> GetOutputShapes() { return output_shapes_; }
+
  private:
   mediapipe::Status InitializeOpenGL(
       std::unique_ptr<InferenceBuilder>* builder);
@@ -73,6 +80,10 @@ class TFLiteGPURunner {
   InferenceOptions options_;
   std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
 
+#ifdef __ANDROID__
+  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
+#endif
+
   // graph_ is maintained temporarily and becomes invalid after runner_ is ready
   std::unique_ptr<GraphFloat32> graph_gl_;
   std::unique_ptr<GraphFloat32> graph_cl_;
third_party/com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff (new vendored file, 50 lines)
@@ -0,0 +1,50 @@
diff --git a/googletest/include/gtest/internal/gtest-internal.h b/googletest/include/gtest/internal/gtest-internal.h
index 7f1a5b00e..c36029ee1 100644
--- a/googletest/include/gtest/internal/gtest-internal.h
+++ b/googletest/include/gtest/internal/gtest-internal.h
@@ -94,6 +94,12 @@ namespace proto2 {
 class MessageLite;
 }
 
+namespace google {
+namespace protobuf {
+class MessageLite;
+}
+}
+
 namespace testing {
 
 // Forward declarations.
@@ -881,10 +887,15 @@ class GTEST_API_ Random {
     typename std::remove_const<typename std::remove_reference<T>::type>::type
 
 // IsAProtocolMessage<T>::value is a compile-time bool constant that's
-// true if and only if T is type proto2::MessageLite or a subclass of it.
+// true if and only if T is type proto2::MessageLite or
+// google::protobuf::MessageLite or a subclass of one of them.
 template <typename T>
 struct IsAProtocolMessage
-    : public std::is_convertible<const T*, const ::proto2::MessageLite*> {};
+    : public std::integral_constant<
+          bool,
+          std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
+          std::is_convertible<
+              const T*, const ::google::protobuf::MessageLite*>::value> {};
 
 // When the compiler sees expression IsContainerTest<C>(0), if C is an
 // STL-style container class, the first overload of IsContainerTest
diff --git a/googletest/test/gtest_unittest.cc b/googletest/test/gtest_unittest.cc
index 005a2d40d..631180e3d 100644
--- a/googletest/test/gtest_unittest.cc
+++ b/googletest/test/gtest_unittest.cc
@@ -7115,6 +7115,10 @@ TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAProtocolMessage) {
   EXPECT_TRUE(IsAProtocolMessage<::proto2::MessageLite>::value);
 }
 
+TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAnOpenSourceProtocolMessage) {
+  EXPECT_TRUE(IsAProtocolMessage<::google::protobuf::MessageLite>::value);
+}
+
 // Tests that IsAProtocolMessage<T>::value is false when T is neither
 // ::proto2::Message nor a sub-class of it.
 TEST(IsAProtocolMessageTest, ValueIsFalseWhenTypeIsNotAProtocolMessage) {