Project import generated by Copybara.
GitOrigin-RevId: e3a43e4e5e519cd14df7095749059e2613bdcf76
This commit is contained in:
parent 67bd8a2bf0
commit e9fbe868e5

BUILD (2 changes)
@@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors.
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -101,7 +101,7 @@ run code search using

## Videos

* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
* [YouTube Channel](https://www.youtube.com/c/MediaPipe)

## Events

@@ -123,7 +123,7 @@ run code search using

* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
  MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users
* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
  community discussion around MediaPipe

WORKSPACE (13 changes)
@@ -37,10 +37,19 @@ http_archive(
)

# GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30.
http_archive(
    name = "com_google_googletest",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
    strip_prefix = "googletest-master",
    urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
    patches = [
        # fix for https://github.com/google/googletest/issues/2817
        "@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
    ],
    patch_args = [
        "-p1",
    ],
    strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
    sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
)

# Google Benchmark library.

build_ios_examples.sh (new file, 74 lines)
@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
#
# Script to build all MediaPipe iOS example apps.
#
# To build all apps and store them in out_dir:
#   $ ./build_ios_examples.sh -d out_dir
#   Omitting -d and the associated directory saves all generated IPAs in the
#   current directory.
#   $ ./build_ios_examples.sh -d out_dir --nostrip
#   Same as above except that the symbols are not stripped.

set -e

out_dir="."
strip=true
app_dir="mediapipe/examples/ios"
bin_dir="bazel-bin"
declare -a default_bazel_flags=(build -c opt --config=ios_arm64)

while [[ -n $1 ]]; do
  case $1 in
    -d)
      shift
      out_dir=$1
      ;;
    --nostrip)
      strip=false
      ;;
    *)
      echo "Unsupported input argument $1."
      exit 1
      ;;
  esac
  shift
done

echo "app_dir: $app_dir"
echo "out_dir: $out_dir"
echo "strip: $strip"

declare -a bazel_flags

apps="${app_dir}/*"
for app in ${apps}; do
  if [[ -d "${app}" ]]; then
    target_name=${app##*/}
    target="${app}:${target_name}"

    echo "=== Target: ${target}"

    bazel_flags=("${default_bazel_flags[@]}")
    bazel_flags+=(${target})
    if [[ $strip == true ]]; then
      bazel_flags+=(--linkopt=-s)
    fi

    bazel "${bazel_flags[@]}"
    cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}"
  fi
done
@@ -149,15 +149,15 @@ When possible, these calculators use platform-specific functionality to share da

The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU.

| ![How GPU calculators interact](../images/gpu_example_graph.png) |
| :--------------------------------------------------------------------------: |
| *Video frames from the camera are fed into the graph as `GpuBuffer` packets.
: The input stream is accessed by two calculators in parallel. :
: `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, :
: which is then sent through a grayscale converter and a canny filter (both :
: based on OpenCV and running on the CPU), whose output is then converted into :
: a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, :
: takes as input both the original `GpuBuffer` and the one coming out of the :
: edge detector, and overlays them using a shader. The output is then sent :
: back to the application using a callback calculator, and the application :
: renders the image to the screen using OpenGL.* :
![How GPU calculators interact](../images/gpu_example_graph.png)

Video frames from the camera are fed into the graph as `GpuBuffer` packets. The
input stream is accessed by two calculators in parallel.
`GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,
which is then sent through a grayscale converter and a canny filter (both based
on OpenCV and running on the CPU), whose output is then converted into a
`GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as
input both the original `GpuBuffer` and the one coming out of the edge detector,
and overlays them using a shader. The output is then sent back to the
application using a callback calculator, and the application renders the image
to the screen using OpenGL.
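To make the described flow concrete, here is a minimal sketch of such a graph in
`CalculatorGraphConfig` text format. The GPU/CPU conversion calculators exist in
MediaPipe, but the grayscale, Canny and overlay calculator names, their tags, and
all stream names below are illustrative assumptions taken from the description
above, not a shipped graph.

```
# Sketch only: grayscale/canny/overlay calculator names are assumptions.
input_stream: "input_video"       # GpuBuffer packets from the camera
output_stream: "output_video"

node {
  calculator: "GpuBufferToImageFrameCalculator"
  input_stream: "input_video"
  output_stream: "video_cpu"      # ImageFrame for the OpenCV/CPU path
}
node {
  calculator: "GrayscaleCalculator"   # assumed name
  input_stream: "video_cpu"
  output_stream: "gray_cpu"
}
node {
  calculator: "CannyEdgeCalculator"   # assumed name
  input_stream: "gray_cpu"
  output_stream: "edges_cpu"
}
node {
  calculator: "ImageFrameToGpuBufferCalculator"
  input_stream: "edges_cpu"
  output_stream: "edges_gpu"
}
node {
  calculator: "GlOverlayCalculator"   # multi-input GPU overlay from the text
  input_stream: "VIDEO:input_video"   # original frames
  input_stream: "OVERLAY:edges_gpu"   # edge image from the CPU path
  output_stream: "output_video"       # returned to the app via a callback
}
```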
@@ -184,12 +184,8 @@ app:

### Prerequisite

1. Install [Xcode](https://developer.apple.com/xcode/) and the Command Line
   Tools.

   Follow Apple's instructions to obtain the required development certificates
   and provisioning profiles for your iOS device. Install the Command Line
   Tools by
1. Install [Xcode](https://developer.apple.com/xcode/), and additionally
   install the Command Line Tools by:

   ```bash
   xcode-select --install
@@ -209,26 +205,31 @@ app:
   pip3 install --user six
   ```

4. Clone the MediaPipe repository.
4. Follow
   [Apple's instructions](https://developer.apple.com/support/certificates/) to
   obtain the required development certificates and provisioning profiles for
   your iOS device.

   Tip: You can the following command to see the provisioning profiles you have
   previously downloaded using Xcode: `open
   ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
   and download a profile on
   [Apple's developer site](https://developer.apple.com/account/resources/).

5. Clone the MediaPipe repository.

   ```bash
   git clone https://github.com/google/mediapipe.git
   ```

5. Symlink or copy your provisioning profile to
   `mediapipe/mediapipe/provisioning_profile.mobileprovision`.
6. In the cloned MediaPipe repository, symlink or copy your provisioning profile
   to `mediapipe/provisioning_profile.mobileprovision`, e.g.,

   ```bash
   cd mediapipe
   ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
   ```

   Tip: You can use this command to see the provisioning profiles you have
   previously downloaded using Xcode: `open
   ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
   and download a profile on
   [Apple's developer site](https://developer.apple.com/account/resources/).

### Option 1: Build with Bazel in Command Line

1. Modify the `bundle_id` field of the app's `ios_application` build target to
@@ -246,6 +247,10 @@ app:

   You may see a permission request from `codesign` in order to sign the app.

   Tip: You can run this
   [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
   to build all MediaPipe iOS example apps.

3. In Xcode, open the `Devices and Simulators` window (command-shift-2).

4. Make sure your device is connected. You will see a list of installed apps.

@@ -44,6 +44,18 @@ apps, see these [instructions](./building_examples.md#ios).
   [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
   to install Bazel 2.0 or higher.

   For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to
   be built from source.

   ```bash
   # For Bazel 3.0.0
   wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip
   sudo apt-get install build-essential openjdk-8-jdk python zip unzip
   unzip bazel-3.0.0-dist.zip
   env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
   sudo cp output/bazel /usr/local/bin/
   ```

3. Install OpenCV and FFmpeg.

   Option 1. Use package manager tool to install the pre-compiled OpenCV

@@ -58,6 +70,14 @@ apps, see these [instructions](./building_examples.md#ios).
   libopencv-imgproc-dev libopencv-video-dev
   ```

   [`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia
   Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be
   modified.

   ```bash
   sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
   ```

   Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source
   and modify MediaPipe's OpenCV config.
@@ -493,14 +513,14 @@ cameras. Alternatively, you use a video file as input.

   ```bash
   username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
   https://storage.googleapis.com/bazel/2.0.0/release/bazel-2.0.0-installer-linux-x86_64.sh && \
   sudo mkdir -p /usr/local/bazel/2.0.0 && \
   chmod 755 bazel-2.0.0-installer-linux-x86_64.sh && \
   sudo ./bazel-2.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/2.0.0 && \
   source /usr/local/bazel/2.0.0/lib/bazel/bin/bazel-complete.bash
   https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \
   sudo mkdir -p /usr/local/bazel/3.0.0 && \
   chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \
   sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \
   source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash

   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \
   alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel'
   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \
   alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel'
   ```

6. Checkout MediaPipe repository.

@@ -101,7 +101,7 @@ run code search using

## Videos

* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
* [YouTube Channel](https://www.youtube.com/c/MediaPipe)

## Events

@@ -123,7 +123,7 @@ run code search using

* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
  MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users
* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
  community discussion around MediaPipe

@@ -1,6 +1,6 @@
---
layout: default
title: Hand
title: Hands
parent: Solutions
nav_order: 3
---

@@ -219,9 +219,13 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web).

## Resources

* Google AI Blog: [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* TensorFlow Blog: [Face and hand tracking in the browser with MediaPipe and
  TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
* Google AI Blog:
  [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* TensorFlow Blog:
  [Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
* Paper:
  [MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214)
  ([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk))
* Palm detection model:
  [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite),
  [TF.js model](https://tfhub.dev/mediapipe/handdetector/1)

@@ -188,5 +188,8 @@ to visualize its associated subgraphs, please see
  [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
  Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper:
  [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
  ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite)
* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite)

@@ -21,7 +21,16 @@ available on Linux, Android, or iOS.

## Enabling tracing and profiling

To enable tracing/profiling of a mediapipe graph, the `CalculatorGraphConfig` (in
To enable tracing and profiling of a mediapipe graph:

1. The profiling library must be linked to the framework.
2. Tracing and profiling must be enabled in the graph configuration.

The profiling library is linked to the framework by default. If needed,
the profiling library can be omitted from the framework using the bazel
command line option: `--define MEDIAPIPE_PROFILING=0`.

To enable tracing and profiling, the `CalculatorGraphConfig` (in
[calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
representing the graph must have a `profiler_config` message at its root. Here
is a simple setup that turns on a few extra options:
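As a minimal sketch of such a setup (the field names should be checked against
the `ProfilerConfig` message in calculator.proto; the interval value below is an
illustrative assumption, not a recommendation):

```
profiler_config {
  trace_enabled: true            # emit trace events for the graph
  enable_profiler: true          # collect per-calculator profiling data
  trace_log_interval_count: 200  # assumed value: how often trace logs are flushed
}
```

The rest of the graph (its `node` entries) follows this message as usual.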

@@ -386,14 +386,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const int input_width = input_mat.cols;
const int input_height = input_mat.rows;
if (!output_height_ || !output_width_) {
output_height_ = input_height;
output_width_ = input_width;
}
int output_width;
int output_height;
ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height);

if (output_width_ > 0 && output_height_ > 0) {
cv::Mat scaled_mat;
int output_width = output_width_;
int output_height = output_height_;
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
int scale_flag =
input_mat.cols > output_width_ && input_mat.rows > output_height_

@@ -416,7 +415,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const int bottom = output_height_ - target_height - top;
const int left = (output_width_ - target_width) / 2;
const int right = output_width_ - target_width - left;
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left, right,
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left,
right,
options_.constant_padding() ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE);
} else {

@@ -426,6 +426,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
output_height = target_height;
}
}
input_mat = scaled_mat;
}

if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
auto padding = absl::make_unique<std::array<float, 4>>();

@@ -437,10 +439,33 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
}

cv::Mat rotated_mat;
cv::Size rotated_size(output_width, output_height);
if (input_mat.size() == rotated_size) {
const int angle = RotationModeToDegrees(rotation_);
cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0);
cv::Point2f src_center(input_mat.cols / 2.0, input_mat.rows / 2.0);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size());
cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
} else {
switch (rotation_) {
case mediapipe::RotationMode_Mode_UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0:
LOG(ERROR) << "Not rotating image.";
rotated_mat = input_mat;
break;
case mediapipe::RotationMode_Mode_ROTATION_90:
LOG(ERROR) << "Rotating image by 90 degrees ccw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
break;
case mediapipe::RotationMode_Mode_ROTATION_180:
LOG(ERROR) << "Rotating image by 180 degrees.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
break;
case mediapipe::RotationMode_Mode_ROTATION_270:
LOG(ERROR) << "Rotating image by 90 degrees cw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
break;
}
}

cv::Mat flipped_mat;
if (flip_horizontally_ || flip_vertically_) {

@@ -139,7 +139,6 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
static_cast<::mediapipe::StatusCode>(status.code()),
status.ToString());
}

auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session);
|
@ -14,6 +14,7 @@
|
|||
#
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
|
@ -202,6 +203,13 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "gpu_inference_disabled",
|
||||
match_any = [
|
||||
"//mediapipe/gpu:disable_gpu",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tflite_inference_calculator",
|
||||
srcs = ["tflite_inference_calculator.cc"],
|
||||
|
@ -226,13 +234,14 @@ cc_library(
|
|||
"@com_google_absl//absl/memory",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/util:resource_util",
|
||||
"//mediapipe/util/tflite:config",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
] + selects.with_or({
|
||||
":gpu_inference_disabled": [],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:MPPMetalHelper",
|
||||
"//mediapipe/gpu:MPPMetalUtil",
|
||||
|
@ -285,6 +294,7 @@ cc_library(
|
|||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite:config",
|
||||
":util",
|
||||
":tflite_converter_calculator_cc_proto",
|
||||
"//mediapipe/util:resource_util",
|
||||
|
@ -295,23 +305,26 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
] + selects.with_or({
|
||||
":gpu_inference_disabled": [],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:MPPMetalUtil",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:MPPMetalHelper",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||
],
|
||||
}) + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -348,8 +361,8 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/util:resource_util",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
] + selects.with_or({
|
||||
":gpu_inference_disabled": [],
|
||||
"//mediapipe:ios": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
|
@ -404,6 +417,7 @@ cc_library(
|
|||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite:config",
|
||||
":util",
|
||||
":tflite_tensors_to_detections_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
@ -415,8 +429,8 @@ cc_library(
|
|||
"//mediapipe/framework/formats/object_detection:anchor_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
] + selects.with_or({
|
||||
":gpu_inference_disabled": [],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:MPPMetalUtil",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
|
@ -492,6 +506,8 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# To run this with native GPU on Linux, use:
|
||||
# bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local
|
||||
cc_test(
|
||||
name = "tflite_inference_calculator_test",
|
||||
srcs = ["tflite_inference_calculator_test.cc"],
|
||||
|
|
|
@ -22,19 +22,23 @@
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#ifndef MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
#if defined(MEDIAPIPE_IOS)
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalKit/MetalKit.h>
|
||||
|
@ -43,13 +47,7 @@
|
|||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // iOS
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
namespace {
|
||||
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
||||
|
@ -73,7 +71,7 @@ constexpr char kMatrixTag[] = "MATRIX";
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlProgram;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
|
@ -83,13 +81,13 @@ struct GPUData {
|
|||
GlShader shader;
|
||||
GlProgram program;
|
||||
};
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
GpuTensor buffer;
|
||||
id<MTLComputePipelineState> pipeline_state;
|
||||
};
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -157,13 +155,13 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
|||
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_out_;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
std::unique_ptr<GPUData> gpu_data_out_;
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
bool initialized_ = false;
|
||||
bool use_gpu_ = false;
|
||||
|
@ -178,6 +176,18 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
|||
};
|
||||
REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||
|
||||
namespace {
|
||||
template <class CC>
|
||||
bool ShouldUseGpu(CC* cc) {
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
return cc->Inputs().HasTag(kGpuBufferTag) ||
|
||||
cc->Outputs().HasTag(kTensorsGpuTag);
|
||||
#else
|
||||
return false;
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
}
|
||||
} // namespace
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
// Confirm only one of the input streams is present.
|
||||
|
@ -189,37 +199,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
|
||||
cc->Outputs().HasTag(kTensorsGpuTag));
|
||||
|
||||
bool use_gpu = false;
|
||||
|
||||
if (cc->Inputs().HasTag(kImageFrameTag)) {
|
||||
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kMatrixTag)) {
|
||||
cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
|
||||
}
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
#ifndef MEDIAPIPE_DISABLE_GPU
|
||||
if (cc->Inputs().HasTag(kGpuBufferTag)) {
|
||||
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->Outputs().HasTag(kTensorsTag)) {
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
}
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
||||
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
if (ShouldUseGpu(cc)) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
|
||||
// Assign this calculator's default InputStreamHandler.
|
||||
|
@ -233,14 +237,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Inputs().HasTag(kGpuBufferTag) ||
|
||||
cc->Outputs().HasTag(kGpuBufferTag)) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
use_gpu_ = true;
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
||||
#endif
|
||||
}
|
||||
use_gpu_ = ShouldUseGpu(cc);
|
||||
|
||||
if (use_gpu_) {
|
||||
// Cannot mix CPU/GPU streams.
|
||||
|
@ -248,12 +245,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
cc->Outputs().HasTag(kTensorsGpuTag));
|
||||
// Cannot use quantization.
|
||||
use_quantized_tensors_ = false;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
} else {
|
||||
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
||||
interpreter_->AddTensors(1);
|
||||
|
@ -282,12 +279,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
interpreter_.reset();
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
||||
#endif
|
||||
#if defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
gpu_data_out_.reset();
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -318,8 +315,14 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
|
||||
<< "Only 8-bit input images are supported for quantization.";
|
||||
quant.type = kTfLiteAffineQuantization;
|
||||
quant.params = nullptr;
|
||||
// Optional: Set 'quant' quantization params here if needed.
|
||||
auto quant_params = static_cast<TfLiteAffineQuantization*>(
|
||||
malloc(sizeof(TfLiteAffineQuantization)));
|
||||
quant_params->scale = TfLiteFloatArrayCreate(1);
|
||||
quant_params->scale->data[0] = 1.0;
|
||||
quant_params->zero_point = TfLiteIntArrayCreate(1);
|
||||
quant_params->zero_point->data[0] = 0;
|
||||
quant_params->quantized_dimension = 0;
|
||||
quant.params = quant_params;
|
||||
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
|
||||
{channels_preserved}, quant);
|
||||
} else {
|
||||
|
@ -414,7 +417,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
||||
const auto& input =
|
||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||
|
@ -451,7 +454,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
cc->Outputs()
|
||||
.Tag(kTensorsGpuTag)
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
// GpuBuffer to id<MTLBuffer> conversion.
|
||||
const auto& input =
|
||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||
|
@ -490,13 +493,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing is not enabled.";
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
// Get input image sizes.
|
||||
const auto& input =
|
||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||
|
@ -512,9 +515,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
RET_CHECK_FAIL() << "Unsupported GPU input format.";
|
||||
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
|
||||
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
|
||||
// Device memory.
|
||||
|
@ -559,7 +562,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
RET_CHECK(include_alpha)
|
||||
<< "iOS GPU inference currently accepts only RGBA input.";
|
||||
|
@ -616,7 +619,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
RET_CHECK(gpu_data_out_->pipeline_state != nil)
|
||||
<< "Couldn't create pipeline state "
|
||||
<< [[error localizedDescription] UTF8String];
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mediapipe/calculators/tflite/util.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
|
||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||
#include "mediapipe/util/cpu_util.h"
|
||||
|
@ -33,7 +34,7 @@
|
|||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
|
||||
|
@ -42,9 +43,9 @@
|
|||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
#if defined(MEDIAPIPE_IOS)
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalKit/MetalKit.h>
|
||||
|
@ -56,7 +57,7 @@
|
|||
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
||||
#endif // iOS
|
||||
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
#if !defined(MEDIAPIPE_EDGE_TPU)
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
@ -71,12 +72,6 @@ int NumGroups(const int size, const int group_size) { // NOLINT
|
|||
return (size + group_size - 1) / group_size;
|
||||
}
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
#endif
|
||||
|
||||
// Round up n to next multiple of m.
|
||||
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
|
||||
|
||||
|
@ -112,13 +107,13 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
|
|||
// * Aux
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
using ::tflite::gpu::gl::CopyBuffer;
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlBuffer;
|
||||
#endif
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
namespace {
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
|
@ -126,7 +121,7 @@ struct GPUData {
|
|||
::tflite::gpu::BHWC shape;
|
||||
};
|
||||
} // namespace
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
|
||||
// Returns number of threads to configure XNNPACK delegate with.
|
||||
// (Equal to user provided value if specified. Otherwise, it returns number of
|
||||
|
@ -152,7 +147,7 @@ int GetXnnpackNumThreads(
|
|||
// Creates an interpreter with given model and calls invoke().
|
||||
// Optionally run inference on CPU/GPU.
|
||||
//
|
||||
// This calculator is designed to be used with the TfLiteConverterCalcualtor,
|
||||
// This calculator is designed to be used with the TfLiteConverterCalculator,
|
||||
// to get the appropriate inputs.
|
||||
//
|
||||
// When the input tensors are on CPU, gpu inference is optional and can be
|
||||
|
@ -183,7 +178,6 @@ int GetXnnpackNumThreads(
|
|||
// options: {
|
||||
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
// model_path: "modelname.tflite"
|
||||
// delegate { gpu {} }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
@ -192,11 +186,12 @@ int GetXnnpackNumThreads(
|
|||
//
|
||||
// node {
|
||||
// calculator: "TfLiteInferenceCalculator"
|
||||
// input_stream: "TENSORS:tensor_image"
|
||||
// input_stream: "TENSORS_GPU:tensor_image"
|
||||
// input_side_packet: "MODEL:model"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// output_stream: "TENSORS_GPU:tensors"
|
||||
// options: {
|
||||
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
// model_path: "modelname.tflite"
|
||||
// delegate { gpu {} }
|
||||
// }
|
||||
// }
|
||||
|
@ -228,24 +223,45 @@ class TfLiteInferenceCalculator : public CalculatorBase {
|
|||
::mediapipe::Status LoadModel(CalculatorContext* cc);
|
||||
::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
|
||||
::mediapipe::Status LoadDelegate(CalculatorContext* cc);
|
||||
::mediapipe::Status InitTFLiteGPURunner();
|
||||
::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||
::mediapipe::Status ProcessInputsCpu(
|
||||
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
|
||||
::mediapipe::Status ProcessOutputsCpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
|
||||
::mediapipe::Status ProcessInputsGpu(
|
||||
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
|
||||
::mediapipe::Status ProcessOutputsGpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
|
||||
|
||||
::mediapipe::Status RunInContextIfNeeded(
|
||||
std::function<::mediapipe::Status(void)> f) {
|
||||
if (gpu_inference_) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
return gpu_helper_.RunInGlContext(std::move(f));
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
return f();
|
||||
}
|
||||
|
||||
Packet model_packet_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
|
||||
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
|
||||
id<MTLComputePipelineState> fp32_to_fp16_program_;
|
||||
TFLBufferConvert* converter_from_BPHWC4_ = nil;
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ =
|
||||
|
@ -263,6 +279,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
// Calculator Core Section
|
||||
|
||||
namespace {
|
||||
template <class CC>
|
||||
bool ShouldUseGpu(CC* cc) {
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
const auto& options =
|
||||
cc->template Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
|
||||
return options.use_gpu() ||
|
||||
(options.has_delegate() && options.delegate().has_gpu()) ||
|
||||
cc->Inputs().HasTag(kTensorsGpuTag) ||
|
||||
cc->Outputs().HasTag(kTensorsGpuTag);
|
||||
#else
|
||||
return false;
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
}
|
||||
} // namespace
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^
|
||||
|
@ -276,32 +308,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
cc->InputSidePackets().HasTag("MODEL"))
|
||||
<< "Either model as side packet or model path in options is required.";
|
||||
|
||||
bool use_gpu =
|
||||
options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsTag))
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
|
||||
<< "GPU input is compatible with GPU delegate only.";
|
||||
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->Outputs().HasTag(kTensorsTag))
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
||||
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
|
||||
<< "GPU output is compatible with GPU delegate only.";
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag))
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
if (cc->Outputs().HasTag(kTensorsGpuTag))
|
||||
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
cc->InputSidePackets()
|
||||
|
@ -312,10 +327,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
|
||||
}
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
if (ShouldUseGpu(cc)) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
#endif
|
||||
}
|
||||
|
@ -331,149 +346,111 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
|
||||
gpu_inference_ = options.use_gpu();
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
gpu_input_ = true;
|
||||
gpu_inference_ = true; // Inference must be on GPU also.
|
||||
#else
|
||||
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
|
||||
<< "GPU processing not enabled.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
gpu_inference_ = ShouldUseGpu(cc);
|
||||
gpu_input_ = cc->Inputs().HasTag(kTensorsGpuTag);
|
||||
gpu_output_ = cc->Outputs().HasTag(kTensorsGpuTag);
|
||||
|
||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
gpu_output_ = true;
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensorsGpuTag))
|
||||
<< "GPU output must also have GPU Input.";
|
||||
#else
|
||||
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
|
||||
<< "GPU processing not enabled.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
||||
use_advanced_gpu_api_ = false;
|
||||
if (use_advanced_gpu_api_ && !(gpu_input_ && gpu_output_)) {
|
||||
LOG(WARNING)
|
||||
<< "Cannot use advanced GPU APIs, both inputs and outputs must "
|
||||
"be GPU buffers. Falling back to the default TFLite API.";
|
||||
use_advanced_gpu_api_ = MEDIAPIPE_TFLITE_GL_INFERENCE &&
|
||||
options.has_delegate() &&
|
||||
options.delegate().has_gpu() &&
|
||||
options.delegate().gpu().use_advanced_gpu_api();
|
||||
if (use_advanced_gpu_api_ && !gpu_input_) {
|
||||
LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
|
||||
"Falling back to the default TFLite API.";
|
||||
use_advanced_gpu_api_ = false;
|
||||
}
|
||||
CHECK(!use_advanced_gpu_api_ || gpu_inference_);
|
||||
|
||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
|
||||
if (gpu_inference_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
||||
return use_advanced_gpu_api_ ? InitTFLiteGPURunner()
|
||||
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
|
||||
: LoadDelegate(cc);
|
||||
}));
|
||||
if (use_advanced_gpu_api_) return ::mediapipe::OkStatus();
|
||||
#else
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
#endif
|
||||
} else {
|
||||
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID)
|
||||
// TODO: why only on these platforms?
|
||||
// It seems that the XNNPACK delegate fails to load on Linux.
|
||||
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \
|
||||
defined(MEDIAPIPE_IOS)
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
#endif // __EMSCRIPTEN__ || ANDROID
|
||||
#endif // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
|
||||
return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
|
||||
// 0. Declare outputs
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || defined(MEDIAPIPE_IOS)
|
||||
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
|
||||
#endif
|
||||
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
|
||||
// 1. Receive pre-processed tensor inputs.
|
||||
if (use_advanced_gpu_api_ && gpu_output_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
if (gpu_input_) {
|
||||
MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
|
||||
}
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors, &output_tensors_gpu]() -> ::mediapipe::Status {
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
|
||||
input_tensors[i].id(), i));
|
||||
|
||||
// 2. Run inference.
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
if (gpu_inference_ && use_advanced_gpu_api_) {
|
||||
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
|
||||
} else {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
}
|
||||
// Allocate output tensor.
|
||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
GpuTensor& tensor = output_tensors_gpu->at(i);
|
||||
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_[i]->elements, &tensor));
|
||||
MP_RETURN_IF_ERROR(
|
||||
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#endif
|
||||
} else if (gpu_input_) {
|
||||
// Read GPU input into SSBO.
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors]() -> ::mediapipe::Status {
|
||||
// Explicit copy input.
|
||||
gpu_data_in_.resize(input_tensors.size());
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
RET_CHECK_CALL(
|
||||
CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
|
||||
#else
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
// 3. Output processed tensors.
|
||||
if (gpu_output_ || use_advanced_gpu_api_) {
|
||||
MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
|
||||
std::move(output_tensors_gpu)));
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
});
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
|
||||
return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
|
||||
if (delegate_) {
|
||||
interpreter_ = nullptr;
|
||||
delegate_ = nullptr;
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
if (gpu_inference_) {
|
||||
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
||||
gpu_data_in_[i].reset();
|
||||
}
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
// Explicit copy input with conversion float 32 bits to 16 bits.
|
||||
gpu_data_in_.resize(input_tensors.size());
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteInferenceCalculatorConvert";
|
||||
id<MTLComputeCommandEncoder> compute_encoder =
|
||||
[command_buffer computeCommandEncoder];
|
||||
[compute_encoder setComputePipelineState:fp32_to_fp16_program_];
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
|
||||
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
|
||||
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
|
||||
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
|
||||
const int threadgroups =
|
||||
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
|
||||
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
|
||||
threadsPerThreadgroup:threads_per_group];
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
gpu_data_out_[i].reset();
|
||||
}
|
||||
[compute_encoder endEncoding];
|
||||
[command_buffer commit];
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
||||
}
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
}
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
edgetpu_context_.reset();
|
||||
#endif
|
||||
} else {
|
||||
return ::mediapipe::OkStatus();
|
||||
});
|
||||
}
|
||||
|
||||
// Calculator Auxiliary Section
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
|
||||
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
|
||||
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -496,39 +473,128 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
input_tensor->bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Run inference.
|
||||
if (gpu_inference_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
|
||||
if (use_advanced_gpu_api_) {
|
||||
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
|
||||
} else {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
#endif
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu(
|
||||
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu) {
|
||||
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
if (use_advanced_gpu_api_) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
tflite_gpu_runner_->BindSSBOToInputTensor(input_tensors[i].id(), i));
|
||||
}
|
||||
if (gpu_output_) {
|
||||
// Allocate new output tensor.
|
||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
GpuTensor& tensor = output_tensors_gpu->at(i);
|
||||
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_[i]->elements, &tensor));
|
||||
MP_RETURN_IF_ERROR(
|
||||
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
|
||||
}
|
||||
} else {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
// Re-use internal output tensor.
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
|
||||
gpu_data_out_[i]->buffer.id(), i));
|
||||
}
|
||||
}
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
} else if (gpu_input_) {
|
||||
// Read GPU input into SSBO.
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
// Explicit copy input.
|
||||
gpu_data_in_.resize(input_tensors.size());
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
|
||||
}
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
// Explicit copy input with conversion float 32 bits to 16 bits.
|
||||
gpu_data_in_.resize(input_tensors.size());
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteInferenceCalculatorConvert";
|
||||
id<MTLComputeCommandEncoder> compute_encoder =
|
||||
[command_buffer computeCommandEncoder];
|
||||
[compute_encoder setComputePipelineState:fp32_to_fp16_program_];
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
|
||||
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
|
||||
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
|
||||
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
|
||||
const int threadgroups =
|
||||
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
|
||||
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
|
||||
threadsPerThreadgroup:threads_per_group];
|
||||
}
|
||||
[compute_encoder endEncoding];
|
||||
[command_buffer commit];
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
|
||||
// 3. Output processed tensors.
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
output_tensors_cpu->emplace_back(*tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
|
||||
if (use_advanced_gpu_api_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
if (gpu_output_) {
|
||||
// Send out pre-allocated tensors.
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsGpuTag)
|
||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||
#endif
|
||||
} else {
|
||||
// Download to CPU for output.
|
||||
const auto& tensor_indexes = interpreter_->inputs();
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
std::vector<float> gpu_data(tensor->bytes / sizeof(float));
|
||||
RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
|
||||
absl::MakeSpan(tensor->data.f, tensor->bytes)));
|
||||
output_tensors_cpu->emplace_back(*tensor);
|
||||
}
|
||||
// Output result tensors (CPU).
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
||||
}
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
} else if (gpu_output_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
// Output result tensors (GPU).
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &output_tensors_gpu]() -> ::mediapipe::Status {
|
||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
GpuTensor& tensor = output_tensors_gpu->at(i);
|
||||
|
@ -537,12 +603,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
gpu_data_out_[i]->elements, &tensor));
|
||||
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsGpuTag)
|
||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
// Output result tensors (GPU).
|
||||
output_tensors_gpu->resize(gpu_data_out_.size());
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
|
@ -566,68 +630,58 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
cc->Outputs()
|
||||
.Tag(kTensorsGpuTag)
|
||||
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
} else {
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
output_tensors_cpu->emplace_back(*tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
|
||||
if (delegate_) {
|
||||
if (gpu_inference_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
||||
interpreter_ = nullptr;
|
||||
delegate_ = nullptr;
|
||||
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
||||
gpu_data_in_[i].reset();
|
||||
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
|
||||
CalculatorContext* cc) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
op_resolver = cc->InputSidePackets()
|
||||
.Tag("CUSTOM_OP_RESOLVER")
|
||||
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
}
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
gpu_data_out_[i].reset();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
interpreter_ = nullptr;
|
||||
delegate_ = nullptr;
|
||||
for (int i = 0; i < gpu_data_in_.size(); ++i) {
|
||||
gpu_data_in_[i].reset();
|
||||
}
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
gpu_data_out_[i].reset();
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
interpreter_ = nullptr;
|
||||
delegate_ = nullptr;
|
||||
}
|
||||
}
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
edgetpu_context_.reset();
|
||||
#endif
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Calculator Auxiliary Section
|
||||
// Create runner
|
||||
tflite::gpu::InferenceOptions options;
|
||||
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
|
||||
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
||||
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
||||
RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
|
||||
|
||||
// Allocate interpreter memory for cpu output.
|
||||
if (!gpu_output_) {
|
||||
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
||||
const int num_outputs = tflite_gpu_runner_->GetOutputShapes().size();
|
||||
interpreter_->AddTensors(num_outputs);
|
||||
std::vector<int> indices(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) indices[i] = i;
|
||||
// There is no ResizeOutputTensor(), so we use 'inputs' space instead.
|
||||
interpreter_->SetInputs(indices);
|
||||
TfLiteQuantization quant;
|
||||
quant.type = kTfLiteNoQuantization;
|
||||
quant.params = nullptr;
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
auto shape = tflite_gpu_runner_->GetOutputShapes()[i];
|
||||
const int tensor_idx = interpreter_->inputs()[i];
|
||||
interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
|
||||
{shape.c}, quant);
|
||||
CHECK(interpreter_->ResizeInputTensor(
|
||||
tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk);
|
||||
}
|
||||
CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner() {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
// Create and bind OpenGL buffers for outputs.
|
||||
// These buffers are created once and later their ids are just passed to the
|
||||
// calculator outputs.
|
||||
|
||||
// The buffers are created once and their ids are passed to calculator outputs
|
||||
gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
|
||||
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
|
||||
gpu_data_out_[i] = absl::make_unique<GPUData>();
|
||||
|
@ -638,15 +692,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
|
||||
}
|
||||
RET_CHECK_CALL(tflite_gpu_runner_->Build());
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
|
||||
CalculatorContext* cc) {
|
||||
if (use_advanced_gpu_api_) {
|
||||
// Use InitTFLiteGPURunner for everything.
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
|
||||
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver;
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
op_resolver = cc->InputSidePackets()
|
||||
|
@ -654,19 +713,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
.Get<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
}
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
if (use_advanced_gpu_api_) {
|
||||
tflite::gpu::InferenceOptions options;
|
||||
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
|
||||
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
||||
tflite_gpu_runner_ =
|
||||
std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
||||
return tflite_gpu_runner_->InitializeWithModel(model, op_resolver);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(MEDIAPIPE_EDGE_TPU)
|
||||
interpreter_ =
|
||||
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
|
||||
|
@ -771,7 +817,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
|
@ -832,9 +878,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
// Must call this last.
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||
kTfLiteOk);
|
||||
#endif // OpenGL
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
#if defined(MEDIAPIPE_IOS)
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
const int kHalfSize = 2; // sizeof(half)
|
||||
// Configure and create the delegate.
|
||||
TFLGpuDelegateOptions options;
|
||||
|
@ -958,7 +1004,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
"Error initializating output buffer converter");
|
||||
}
|
||||
}
|
||||
#endif // iOS
|
||||
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -45,6 +45,8 @@ message TfLiteInferenceCalculatorOptions {
|
|||
message Gpu {
|
||||
// Experimental, Android/Linux only. Use TFLite GPU delegate API2 for
|
||||
// the NN inference.
|
||||
// example:
|
||||
// delegate: { gpu { use_advanced_gpu_api: true } }
|
||||
optional bool use_advanced_gpu_api = 1 [default = false];
|
||||
}
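A usage sketch only: the node below shows where the option above would be set in a graph config. The stream names and model path are illustrative placeholders, not part of this change; only the option syntax follows the comment above.

```
# Hypothetical node; stream names and model_path are placeholders.
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS_GPU:input_tensors"
  output_stream: "TENSORS_GPU:inference_tensors"
  options: {
    [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
      model_path: "mediapipe/models/face_detection_front.tflite"
      delegate: { gpu { use_advanced_gpu_api: true } }
    }
  }
}
```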
|
||||
// Android only.
|
||||
|
|
|
@ -25,17 +25,18 @@
|
|||
#include "mediapipe/framework/formats/location.h"
|
||||
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
#if defined(MEDIAPIPE_IOS)
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalKit/MetalKit.h>
|
||||
|
@ -44,7 +45,7 @@
|
|||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // iOS
|
||||
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
namespace {
|
||||
constexpr int kNumInputTensorsWithAnchors = 3;
|
||||
|
@ -56,22 +57,17 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
#endif
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
typedef id<MTLComputePipelineState> GpuProgram;
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
namespace {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
struct GPUData {
|
||||
GpuProgram decode_program;
|
||||
GpuProgram score_program;
|
||||
|
@ -81,7 +77,7 @@ struct GPUData {
|
|||
GpuTensor scored_boxes_buffer;
|
||||
GpuTensor raw_scores_buffer;
|
||||
};
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
|
||||
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
||||
std::vector<Anchor>* anchors) {
|
||||
|
@ -181,13 +177,13 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
|
|||
std::vector<Anchor> anchors_;
|
||||
bool side_packet_anchors_{};
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_;
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
std::unique_ptr<GPUData> gpu_data_;
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
bool gpu_input_ = false;
|
||||
bool anchors_init_ = false;
|
||||
|
@ -205,12 +201,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
}
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->Outputs().HasTag("DETECTIONS")) {
|
||||
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
|
||||
|
@ -223,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
|
@ -239,12 +233,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||
gpu_input_ = true;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
@ -401,7 +395,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GE(input_tensors.size(), 2);
|
||||
|
@ -464,7 +458,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
|
||||
|
@ -546,17 +540,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
#else
|
||||
LOG(ERROR) << "GPU input on non-Android not supported yet.";
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
gpu_data_.reset();
|
||||
#endif
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -705,7 +699,7 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
|
||||
-> ::mediapipe::Status {
|
||||
gpu_data_ = absl::make_unique<GPUData>();
|
||||
|
@ -918,7 +912,7 @@ void main() {
|
|||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
#elif defined(MEDIAPIPE_IOS)
|
||||
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
gpu_data_ = absl::make_unique<GPUData>();
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
|
@ -1148,7 +1142,7 @@ kernel void scoreKernel(
|
|||
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
|
||||
}
|
||||
|
||||
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -217,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
|
|||
for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
|
||||
const Landmark& landmark = output_landmarks.landmark(i);
|
||||
NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark();
|
||||
norm_landmark->set_x(static_cast<float>(landmark.x()) /
|
||||
options_.input_image_width());
|
||||
norm_landmark->set_y(static_cast<float>(landmark.y()) /
|
||||
options_.input_image_height());
|
||||
norm_landmark->set_z(landmark.z() / options_.normalize_z());
|
||||
norm_landmark->set_x(landmark.x() / options_.input_image_width());
|
||||
norm_landmark->set_y(landmark.y() / options_.input_image_height());
|
||||
// Scale the Z coordinate in the same way as X, and allow additional uniform normalization.
|
||||
norm_landmark->set_z(landmark.z() / options_.input_image_width() /
|
||||
options_.normalize_z());
|
||||
norm_landmark->set_visibility(landmark.visibility());
|
||||
}
|
||||
cc->Outputs()
|
||||
|
|
|
@ -29,7 +29,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
|
|||
required int32 num_landmarks = 1;
|
||||
|
||||
// Size of the input image for the model. These options are used only when
|
||||
// normalized landmarks is needed.
|
||||
// normalized landmarks are needed. Z coordinate is scaled as X assuming
|
||||
// a weak perspective projection camera model.
|
||||
optional int32 input_image_width = 2;
|
||||
optional int32 input_image_height = 3;
|
||||
|
||||
|
@ -46,6 +47,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
|
|||
// beforehand.
|
||||
optional bool flip_horizontally = 6 [default = false];
|
||||
|
||||
// A value that z values should be divided by.
|
||||
// A value that Z coordinates should be divided by. This option is used only
|
||||
// when normalized landmarks are needed. It is applied in addition to Z
|
||||
// coordinate being re-scaled as X.
|
||||
optional float normalize_z = 5 [default = 1.0];
|
||||
}
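A minimal options sketch for the normalization described above. The `.ext` extension name follows the usual MediaPipe convention and is assumed here, and the numeric values are placeholders; only the field names come from this message.

```
# Hypothetical options block; landmark count and image size are placeholders.
options: {
  [mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
    num_landmarks: 21
    input_image_width: 256
    input_image_height: 256
    # Z is divided by input_image_width (same scaling as X) and then by normalize_z.
    normalize_z: 1.0
  }
}
```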
|
||||
|
|
|
@ -376,6 +376,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":timed_box_list_id_to_label_calculator_cc_proto",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
|
|
|
@ -122,11 +122,13 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase {
|
|||
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
|
||||
const float new_x = (landmark.x() - left) / (1.0f - left_and_right);
|
||||
const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom);
|
||||
const float new_z =
|
||||
landmark.z() / (1.0f - left_and_right); // Scale Z coordinate as X.
|
||||
|
||||
new_landmark->set_x(new_x);
|
||||
new_landmark->set_y(new_y);
|
||||
// Keep z-coord as is.
|
||||
new_landmark->set_z(landmark.z());
|
||||
new_landmark->set_z(new_z);
|
||||
// Keep visibility as is.
|
||||
new_landmark->set_visibility(landmark.visibility());
|
||||
}
|
||||
|
|
|
@ -123,11 +123,12 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
|||
|
||||
new_x = new_x * input_rect.width() + input_rect.x_center();
|
||||
new_y = new_y * input_rect.height() + input_rect.y_center();
|
||||
const float new_z =
|
||||
landmark.z() * input_rect.width(); // Scale Z coordinate as X.
|
||||
|
||||
new_landmark->set_x(new_x);
|
||||
new_landmark->set_y(new_y);
|
||||
// Keep z-coord as is.
|
||||
new_landmark->set_z(landmark.z());
|
||||
new_landmark->set_z(new_z);
|
||||
// Keep visibility as is.
|
||||
new_landmark->set_visibility(landmark.visibility());
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
|
@ -53,7 +54,7 @@ class TimedBoxListIdToLabelCalculator : public CalculatorBase {
|
|||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unordered_map<int, std::string> label_map_;
|
||||
absl::node_hash_map<int, std::string> label_map_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator);
|
||||
|
||||
|
|
|
@ -1,4 +1 @@
|
|||
MediaPipe Examples
|
||||
==================
|
||||
|
||||
This directory contains MediaPipe Android example applications. Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details.
|
||||
This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details.
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
tricorder: {
|
||||
options: {
|
||||
builder: {
|
||||
config: "android_arm64"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -83,7 +83,7 @@ android_binary(
|
|||
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
||||
manifest_values = {
|
||||
"applicationId": "com.google.mediapipe.apps.objectdetection3d",
|
||||
"appName": "Object Detection 3D",
|
||||
"appName": "Objectron",
|
||||
"mainActivity": ".MainActivity",
|
||||
"cameraFacingFront": "False",
|
||||
"binaryGraphName": "object_detection_3d.binarypb",
|
||||
|
|
|
@ -1,113 +1 @@
|
|||
**Hello World**
|
||||
|
||||
To build the "Hello World" example, use:
|
||||
|
||||
```
|
||||
bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world
|
||||
```
|
||||
|
||||
and then run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/hello_world/hello_world
|
||||
```
|
||||
|
||||
**TFlite Object Detection**
|
||||
|
||||
To build the object detection demo using a TFLite model on desktop, use:
|
||||
|
||||
```
|
||||
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
||||
```
|
||||
|
||||
and run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
|
||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
||||
```
|
||||
|
||||
**TensorFlow Object Detection**
|
||||
|
||||
To build the object detection demo using a TensorFlow model on desktop, use:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \
|
||||
--define MEDIAPIPE_DISABLE_GPU=1
|
||||
```
|
||||
|
||||
and run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
|
||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
||||
```
|
||||
|
||||
**TFlite Hand Detection**
|
||||
|
||||
To build the hand detection demo using a TFLite model on desktop, use:
|
||||
|
||||
```
|
||||
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
||||
```
|
||||
|
||||
and run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt \
|
||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
||||
```
|
||||
|
||||
**TFlite Hand Tracking**
|
||||
|
||||
To build the hand tracking demo using a TFLite model on desktop, use:
|
||||
|
||||
```
|
||||
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
||||
```
|
||||
|
||||
and run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt \
|
||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
||||
```
|
||||
|
||||
**TFlite Multi-Hand Tracking**
|
||||
|
||||
To build the multi-hand tracking demo using a TFLite model on desktop, use:
|
||||
|
||||
```
|
||||
bazel build -c opt mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
|
||||
```
|
||||
|
||||
and run it using:
|
||||
|
||||
```
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_tflite \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt \
|
||||
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
|
||||
```
|
||||
|
||||
To change the number of hands to `x` in this application, change:
|
||||
|
||||
1. `min_size:x` in `CollectionHasMinSizeCalculatorOptions` in `mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt`.
|
||||
2. `max_vec_size:x` in `ClipVectorSizeCalculatorOptions` in `mediapipe/examples/dekstop/hand_tracking/subgraphs/multi_hand_detection_cpu.pbtxt`.
|
||||
This directory contains MediaPipe example applications for desktop. Please see [Solutions](https://solutions.mediapipe.dev) for details.
|
||||
|
|
|
@ -62,8 +62,10 @@ cc_library(
|
|||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
|
@ -126,17 +128,20 @@ cc_test(
|
|||
":content_zooming_calculator",
|
||||
":content_zooming_calculator_cc_proto",
|
||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:benchmark",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -19,16 +19,20 @@
|
|||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_builder.h"
|
||||
|
||||
constexpr char kVideoFrame[] = "VIDEO";
|
||||
constexpr char kVideoSize[] = "VIDEO_SIZE";
|
||||
constexpr char kDetectionSet[] = "DETECTIONS";
|
||||
constexpr char kSalientRegions[] = "SALIENT_REGIONS";
|
||||
constexpr char kDetections[] = "DETECTIONS";
|
||||
constexpr char kDetectedBorders[] = "BORDERS";
|
||||
constexpr char kCropRect[] = "CROP_RECT";
|
||||
// Field-of-view (degrees) of the camera's x-axis (width).
|
||||
// TODO: Parameterize FOV based on camera specs.
|
||||
constexpr float kWidthFieldOfView = 60;
|
||||
|
@ -37,12 +41,12 @@ namespace mediapipe {
|
|||
namespace autoflip {
|
||||
|
||||
// Content zooming calculator zooms in on content when a detection has
|
||||
// "only_required" set true. It does this by computing the value of top/bottom
|
||||
// borders to remove from the output and sends these to the
|
||||
// SceneCroppingCalculator. When more than one detections are received the zoom
|
||||
// box is calculated as the union of the detections. Typical applications
|
||||
// include mobile makeover and autofliplive face reframing. Currently only
|
||||
// supports y-dimension zooming.
|
||||
// "only_required" set true or any raw detection input. It does this by
|
||||
// computing the value of top/bottom borders to remove from the output and sends
|
||||
// these to the SceneCroppingCalculator using BORDERS output or a full rect crop
|
||||
// using CROP_RECT output. When more than one detections are received the
|
||||
// zoom box is calculated as the union of the detections. Typical applications
|
||||
// include mobile makeover and autofliplive face reframing.
|
||||
class ContentZoomingCalculator : public CalculatorBase {
|
||||
public:
|
||||
ContentZoomingCalculator()
|
||||
|
@ -56,26 +60,32 @@ class ContentZoomingCalculator : public CalculatorBase {
|
|||
::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Converts bounds to tilt offset and height.
|
||||
::mediapipe::Status ConvertToTiltZoom(float xmin, float xmax, float ymin,
|
||||
// Converts bounds to tilt offset, pan offset and height.
|
||||
::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
||||
float ymax, int* tilt_offset,
|
||||
int* height);
|
||||
int* pan_offset, int* height);
|
||||
ContentZoomingCalculatorOptions options_;
|
||||
// Detection frame width/height.
|
||||
int frame_height_;
|
||||
int frame_width_;
|
||||
// Path solver used to smooth top/bottom border crop values.
|
||||
std::unique_ptr<KinematicPathSolver> path_solver_height_;
|
||||
std::unique_ptr<KinematicPathSolver> path_solver_width_;
|
||||
std::unique_ptr<KinematicPathSolver> path_solver_offset_;
|
||||
// Are parameters initialized.
|
||||
bool initialized_;
|
||||
// Stores the time of the last "only_required" input.
|
||||
int64 last_only_required_detection_;
|
||||
// Border values of last message with detection.
|
||||
// Rect values of last message with detection(s).
|
||||
int last_measured_height_;
|
||||
int last_measured_x_offset_;
|
||||
int last_measured_y_offset_;
|
||||
// Min border values.
|
||||
float min_height_value_;
|
||||
// Target aspect ratio.
|
||||
float target_aspect_;
|
||||
// Max size of bounding box. If input/output aspect ratios are the same,
|
||||
// will be 1.0. Else, will be less than 1.0 to prevent exceeding the size of
|
||||
// the image in either dimension.
|
||||
float max_frame_value_;
|
||||
};
|
||||
REGISTER_CALCULATOR(ContentZoomingCalculator);
|
||||
|
||||
|
@ -92,8 +102,18 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
|
|||
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Input VIDEO or VIDEO_SIZE must be provided.";
|
||||
}
|
||||
cc->Inputs().Tag(kDetectionSet).Set<DetectionSet>();
|
||||
if (cc->Inputs().HasTag(kSalientRegions)) {
|
||||
cc->Inputs().Tag(kSalientRegions).Set<DetectionSet>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kDetections)) {
|
||||
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -108,29 +128,38 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
|
|||
if (options_.has_min_motion_to_reframe()) {
|
||||
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Deprecated min_motion_to_reframe was set, please set "
|
||||
"in kinematic_options_zoom and kinematic_options_tilt directly.";
|
||||
"in kinematic_options_zoom and kinematic_options_tilt "
|
||||
"directly.";
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status ContentZoomingCalculator::ConvertToTiltZoom(
|
||||
::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
|
||||
float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
|
||||
int* height) {
|
||||
int* pan_offset, int* height) {
|
||||
// Find center of the y-axis offset (for tilt control).
|
||||
float y_center = ymin + (ymax - ymin) / 2;
|
||||
// Find center of the x-axis offset (for pan control).
|
||||
float x_center = xmin + (xmax - xmin) / 2;
|
||||
// Find size and apply scale factor to y-axis.
|
||||
float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
|
||||
// Apply min zoom for cases where the target size is wider than input frame
|
||||
// size.
|
||||
fit_size = fmin(min_height_value_, fit_size);
|
||||
// Apply max frame for cases where the target size is different than input
|
||||
// frame size.
|
||||
fit_size = fmin(max_frame_value_, fit_size);
|
||||
// Prevent box from extending beyond the image.
|
||||
if (y_center - fit_size / 2 < 0) {
|
||||
y_center = fit_size / 2;
|
||||
} else if (y_center + fit_size / 2 > 1) {
|
||||
y_center = 1 - fit_size / 2;
|
||||
}
|
||||
if (x_center - fit_size / 2 < 0) {
|
||||
x_center = fit_size / 2;
|
||||
} else if (x_center + fit_size / 2 > 1) {
|
||||
x_center = 1 - fit_size / 2;
|
||||
}
|
||||
// Scale to pixel coordinates.
|
||||
*tilt_offset = frame_height_ * y_center;
|
||||
*pan_offset = frame_width_ * x_center;
|
||||
*height = frame_height_ * fit_size;
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -151,6 +180,20 @@ namespace {
|
|||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection,
|
||||
float* xmin, float* xmax, float* ymin,
|
||||
float* ymax) {
|
||||
RET_CHECK(detection.location_data().format() ==
|
||||
mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
|
||||
<< "Face detection input is lacking required relative_bounding_box()";
|
||||
const auto& location = detection.location_data().relative_bounding_box();
|
||||
*xmin = fmin(*xmin, location.xmin());
|
||||
*xmax = fmax(*xmax, location.xmin() + location.width());
|
||||
*ymin = fmin(*ymin, location.ymin());
|
||||
*ymax = fmax(*ymax, location.ymin() + location.height());
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
void MakeStaticFeatures(const int top_border, const int bottom_border,
|
||||
const int frame_width, const int frame_height,
|
||||
StaticFeatures* static_feature) {
|
||||
|
@ -173,10 +216,8 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
|||
::mediapipe::Status ContentZoomingCalculator::Process(
|
||||
mediapipe::CalculatorContext* cc) {
|
||||
if (cc->Inputs().HasTag(kVideoFrame)) {
|
||||
cv::Mat frame = mediapipe::formats::MatView(
|
||||
&cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>());
|
||||
frame_width_ = frame.cols;
|
||||
frame_height_ = frame.rows;
|
||||
frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
|
||||
frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
|
||||
} else if (cc->Inputs().HasTag(kVideoSize)) {
|
||||
frame_width_ =
|
||||
cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
|
||||
|
@ -191,10 +232,14 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
|||
path_solver_height_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_zoom(), 0, frame_height_,
|
||||
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||
path_solver_width_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_pan(), 0, frame_width_,
|
||||
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||
path_solver_offset_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_tilt(), 0, frame_height_,
|
||||
static_cast<float>(frame_width_) / kWidthFieldOfView);
|
||||
min_height_value_ = 1.0;
|
||||
max_frame_value_ = 1.0;
|
||||
target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
|
||||
// If target size is set and wider than input aspect, make sure to always
|
||||
// crop the min required amount.
|
||||
if (options_.has_target_size()) {
|
||||
|
@ -203,21 +248,23 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
|||
RET_CHECK_GT(options_.target_size().height(), 0)
|
||||
<< "Provided target height not valid.";
|
||||
float input_aspect = frame_width_ / static_cast<float>(frame_height_);
|
||||
float target_aspect = options_.target_size().width() /
|
||||
target_aspect_ = options_.target_size().width() /
|
||||
static_cast<float>(options_.target_size().height());
|
||||
min_height_value_ =
|
||||
(input_aspect < target_aspect) ? input_aspect / target_aspect : 1.0;
|
||||
max_frame_value_ = std::min(input_aspect / target_aspect_,
|
||||
target_aspect_ / input_aspect);
|
||||
}
|
||||
last_measured_height_ = min_height_value_ * frame_height_;
|
||||
last_measured_height_ = max_frame_value_ * frame_height_;
|
||||
last_measured_x_offset_ = target_aspect_ * frame_width_;
|
||||
last_measured_y_offset_ = frame_width_ / 2;
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
auto detection_set = cc->Inputs().Tag(kDetectionSet).Get<DetectionSet>();
|
||||
bool only_required_found = false;
|
||||
|
||||
// Compute the box that contains all "is_required" detections.
|
||||
float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
|
||||
if (cc->Inputs().HasTag(kSalientRegions)) {
|
||||
auto detection_set = cc->Inputs().Tag(kSalientRegions).Get<DetectionSet>();
|
||||
for (const auto& region : detection_set.detections()) {
|
||||
if (!region.only_required()) {
|
||||
continue;
|
||||
|
@ -225,46 +272,64 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
|||
only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
|
||||
}
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag(kDetections)) {
|
||||
auto raw_detections =
|
||||
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
|
||||
for (const auto& detection : raw_detections) {
|
||||
only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(detection, &xmin, &xmax, &ymin, &ymax));
|
||||
}
|
||||
}
|
||||
|
||||
// Convert bounds to pan/tilt/zoom values in pixel coordinates.
|
||||
int offset, height;
|
||||
MP_RETURN_IF_ERROR(
|
||||
ConvertToTiltZoom(xmin, xmax, ymin, ymax, &offset, &height));
|
||||
int offset_y, height, offset_x;
|
||||
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||
&offset_x, &height));
|
||||
|
||||
if (only_required_found) {
|
||||
// An "only_required" detection was found.
|
||||
last_only_required_detection_ = cc->InputTimestamp().Microseconds();
|
||||
last_measured_height_ = height;
|
||||
last_measured_y_offset_ = offset;
|
||||
last_measured_x_offset_ = offset_x;
|
||||
last_measured_y_offset_ = offset_y;
|
||||
} else if (cc->InputTimestamp().Microseconds() -
|
||||
last_only_required_detection_ >=
|
||||
options_.us_before_zoomout()) {
|
||||
// No only_require detections found within salient regions packets arriving
|
||||
// since us_before_zoomout duration.
|
||||
height = min_height_value_ * frame_height_;
|
||||
offset = frame_height_ / 2;
|
||||
// No "only_required" detections found within salient region packets
// arriving since us_before_zoomout duration.
|
||||
height = max_frame_value_ * frame_height_;
|
||||
offset_x = (target_aspect_ * height) / 2;
|
||||
offset_y = frame_height_ / 2;
|
||||
} else {
|
||||
// No only detection found but using last detection due to
|
||||
// duration_before_zoomout_us setting.
|
||||
height = last_measured_height_;
|
||||
offset = last_measured_y_offset_;
|
||||
offset_x = last_measured_x_offset_;
|
||||
offset_y = last_measured_y_offset_;
|
||||
}
|
||||
|
||||
// Compute smoothed camera paths.
|
||||
MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
|
||||
height, cc->InputTimestamp().Microseconds()));
|
||||
MP_RETURN_IF_ERROR(path_solver_width_->AddObservation(
|
||||
offset_x, cc->InputTimestamp().Microseconds()));
|
||||
MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
|
||||
offset, cc->InputTimestamp().Microseconds()));
|
||||
int path_size;
|
||||
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_size));
|
||||
int path_offset;
|
||||
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset));
|
||||
offset_y, cc->InputTimestamp().Microseconds()));
|
||||
int path_height;
|
||||
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height));
|
||||
int path_offset_x;
|
||||
MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x));
|
||||
int path_offset_y;
|
||||
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));
|
||||
|
||||
// Convert to top/bottom borders to remove.
|
||||
int path_top = path_offset - path_size / 2;
|
||||
int path_bottom = frame_height_ - (path_offset + path_size / 2);
|
||||
int path_top = path_offset_y - path_height / 2;
|
||||
int path_bottom = frame_height_ - (path_offset_y + path_height / 2);
|
||||
|
||||
// Transmit result downstream.
|
||||
// Transmit result downstream to the SceneCroppingCalculator.
|
||||
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||
std::unique_ptr<StaticFeatures> features =
|
||||
absl::make_unique<StaticFeatures>();
|
||||
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
|
||||
|
@ -272,6 +337,18 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
|
|||
cc->Outputs()
|
||||
.Tag(kDetectedBorders)
|
||||
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
// Transmit downstream to glcroppingcalculator.
|
||||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
auto gpu_rect = absl::make_unique<mediapipe::Rect>();
|
||||
gpu_rect->set_x_center(path_offset_x);
|
||||
gpu_rect->set_width(path_height * target_aspect_);
|
||||
gpu_rect->set_y_center(path_offset_y);
|
||||
gpu_rect->set_height(path_height);
|
||||
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -32,6 +32,8 @@ message ContentZoomingCalculatorOptions {
|
|||
optional KinematicOptions kinematic_options_zoom = 6;
|
||||
// Kinematic options for tilt (y-axis reframing.)
|
||||
optional KinematicOptions kinematic_options_tilt = 7;
|
||||
// Kinematic options for pan (x-axis reframing.)
|
||||
optional KinematicOptions kinematic_options_pan = 10;
|
||||
// Duration (in MicroSeconds) before returning to fully zoomed out position
|
||||
// when no "only_required" frames are received.
|
||||
optional int64 us_before_zoomout = 9 [default = 1000000];
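A hedged example of tuning these options inside a graph node. The values are placeholders; the field names match this message and the kinematic setters exercised in the PanConfig/TiltConfig/ZoomConfig tests below.

```
# Hypothetical tuning; thresholds and duration are placeholders.
options: {
  [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
    kinematic_options_pan: { min_motion_to_reframe: 0.0 }
    kinematic_options_tilt: { min_motion_to_reframe: 5.0 }
    kinematic_options_zoom: { min_motion_to_reframe: 5.0 }
    us_before_zoomout: 1000000
  }
}
```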
|
||||
|
|
|
@ -16,10 +16,14 @@
|
|||
|
||||
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/benchmark.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -36,14 +40,14 @@ namespace {
|
|||
const char kConfigA[] = R"(
|
||||
calculator: "ContentZoomingCalculator"
|
||||
input_stream: "VIDEO:camera_frames"
|
||||
input_stream: "DETECTIONS:detection_set"
|
||||
input_stream: "SALIENT_REGIONS:detection_set"
|
||||
output_stream: "BORDERS:borders"
|
||||
)";
|
||||
|
||||
const char kConfigB[] = R"(
|
||||
calculator: "ContentZoomingCalculator"
|
||||
input_stream: "VIDEO:camera_frames"
|
||||
input_stream: "DETECTIONS:detection_set"
|
||||
input_stream: "SALIENT_REGIONS:detection_set"
|
||||
output_stream: "BORDERS:borders"
|
||||
options: {
|
||||
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
||||
|
@ -58,10 +62,17 @@ const char kConfigB[] = R"(
|
|||
const char kConfigC[] = R"(
|
||||
calculator: "ContentZoomingCalculator"
|
||||
input_stream: "VIDEO_SIZE:size"
|
||||
input_stream: "DETECTIONS:detection_set"
|
||||
input_stream: "SALIENT_REGIONS:detection_set"
|
||||
output_stream: "BORDERS:borders"
|
||||
)";
|
||||
|
||||
const char kConfigD[] = R"(
|
||||
calculator: "ContentZoomingCalculator"
|
||||
input_stream: "VIDEO_SIZE:size"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
output_stream: "CROP_RECT:rect"
|
||||
)";
|
||||
|
||||
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||
int top_border, int bottom_border) {
|
||||
ASSERT_EQ(2, static_features.border().size());
|
||||
|
@ -80,6 +91,43 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
|||
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
||||
}
|
||||
|
||||
void AddDetection(const cv::Rect_<float>& position, const int64 time,
|
||||
CalculatorRunner* runner) {
|
||||
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
||||
mediapipe::Detection detection;
|
||||
detection.mutable_location_data()->set_format(
|
||||
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_height(position.height);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
|
||||
position.width);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
|
||||
position.x);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
|
||||
position.y);
|
||||
detections->push_back(detection);
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
|
||||
|
||||
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
|
||||
runner->MutableInputs()
|
||||
->Tag("VIDEO_SIZE")
|
||||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
|
||||
}
|
||||
|
||||
void CheckCropRect(const int x_center, const int y_center, const int width,
|
||||
const int height, const int frame_number,
|
||||
const std::vector<Packet>& output_packets) {
|
||||
ASSERT_GT(output_packets.size(), frame_number);
|
||||
const auto& rect = output_packets[frame_number].Get<mediapipe::Rect>();
|
||||
EXPECT_EQ(rect.x_center(), x_center);
|
||||
EXPECT_EQ(rect.y_center(), y_center);
|
||||
EXPECT_EQ(rect.width(), width);
|
||||
EXPECT_EQ(rect.height(), height);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
|
||||
|
@ -98,7 +146,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
|||
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag("SALIENT_REGIONS")
|
||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||
|
||||
// Run the calculator.
|
||||
|
@ -111,6 +159,66 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
|||
CheckBorder(static_features, 1000, 1000, 495, 395);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, PanConfig) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
|
||||
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(488, 550, 111, 111, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, TiltConfig) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
|
||||
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(450, 588, 111, 111, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ZoomConfig) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
|
||||
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(450, 550, 139, 139, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
|
||||
|
@ -129,7 +237,7 @@ TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
|||
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag("SALIENT_REGIONS")
|
||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||
|
||||
// Run the calculator.
|
||||
|
@ -166,7 +274,7 @@ TEST(ContentZoomingCalculatorTest, TwoFacesWide) {
|
|||
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag("SALIENT_REGIONS")
|
||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||
|
||||
// Run the calculator.
|
||||
|
@ -191,7 +299,7 @@ TEST(ContentZoomingCalculatorTest, NoDetectionOnInit) {
|
|||
Adopt(input_frame.release()).At(Timestamp(0)));
|
||||
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag("SALIENT_REGIONS")
|
||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||
|
||||
// Run the calculator.
|
||||
|
@ -223,7 +331,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
|
|||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
|
||||
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag("SALIENT_REGIONS")
|
||||
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
|
||||
|
||||
// Run the calculator.
|
||||
|
|
|
@ -37,7 +37,7 @@ node {
|
|||
output_stream: "TENSORS:detection_tensors"
|
||||
options: {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
model_path: "face_detection_front.tflite"
|
||||
model_path: "mediapipe/models/face_detection_front.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ node {
|
|||
output_stream: "labeled_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||
label_map_path: "face_detection_front_labelmap.txt"
|
||||
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1 @@
|
|||
This directory contains example MediaPipe applications on iOS.
|
||||
|
||||
| Use Case | Directory |
|
||||
|---------------------------------------|:-----------------------------------:|
|
||||
| Edge Detection on GPU | edgedetection |
|
||||
| Face Detection on CPU | facedetectioncpu |
|
||||
| Face Detection on GPU | facedetectiongpu |
|
||||
| Object Detection on CPU | objectdetectioncpu |
|
||||
| Object Detection on GPU | objectdetectiongpu |
|
||||
| Hand Detection on GPU | handdetectiongpu |
|
||||
| Hand Tracking on GPU | handtrackinggpu |
|
||||
|
||||
For instance, to build an example app for face detection on CPU, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
|
||||
```
|
||||
(Note: with your own $XCODE_VERSION)
|
||||
This directory contains MediaPipe example applications for iOS. Please see [Solutions](https://solutions.mediapipe.dev)for details.
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "edgedetectiongpu",
|
||||
actual = "EdgeDetectionGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "EdgeDetectionGpuApp",
|
||||
bundle_id = "com.google.mediapipe.EdgeDetectionGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "facedetectioncpu",
|
||||
actual = "FaceDetectionCpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "FaceDetectionCpuApp",
|
||||
bundle_id = "com.google.mediapipe.FaceDetectionCpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "facedetectiongpu",
|
||||
actual = "FaceDetectionGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "FaceDetectionGpuApp",
|
||||
bundle_id = "com.google.mediapipe.FaceDetectionGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
|||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
|
||||
alias(
|
||||
name = "facemeshgpu",
|
||||
actual = "FaceMeshGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "FaceMeshGpuApp",
|
||||
bundle_id = "com.google.mediapipe.FaceMeshGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "handdetectiongpu",
|
||||
actual = "HandDetectionGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "HandDetectionGpuApp",
|
||||
bundle_id = "com.google.mediapipe.HandDetectionGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
|||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
|
||||
alias(
|
||||
name = "handtrackinggpu",
|
||||
actual = "HandTrackingGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "HandTrackingGpuApp",
|
||||
bundle_id = "com.google.mediapipe.HandTrackingGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
|
|||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
|
||||
alias(
|
||||
name = "multihandtrackinggpu",
|
||||
actual = "MultiHandTrackingGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "MultiHandTrackingGpuApp",
|
||||
bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "objectdetectioncpu",
|
||||
actual = "ObjectDetectionCpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "ObjectDetectionCpuApp",
|
||||
bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import "AppDelegate.h"
|
||||
#import "ViewController.h"
|
||||
|
||||
@interface AppDelegate ()
|
||||
|
||||
|
@ -22,7 +23,14 @@
|
|||
|
||||
- (BOOL)application:(UIApplication *)application
|
||||
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
|
||||
// Override point for customization after application launch.
|
||||
ViewController *viewController = (ViewController *)self.window.rootViewController;
|
||||
NSURL *url = [launchOptions objectForKey:UIApplicationLaunchOptionsURLKey];
|
||||
// Unattended testing on Firebase is enabled by custom URL schema.
|
||||
if ([url.scheme isEqualToString:@"firebase-game-loop"]) {
|
||||
[viewController setSourceMode:MediaPipeDemoSourceVideo];
|
||||
} else {
|
||||
[viewController setSourceMode:MediaPipeDemoSourceBackCamera];
|
||||
}
|
||||
return YES;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,11 @@ load(
|
|||
"ios_application",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "objectdetectiongpu",
|
||||
actual = "ObjectDetectionGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "ObjectDetectionGpuApp",
|
||||
bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
|
||||
|
|
|
@ -38,5 +38,18 @@
|
|||
<array>
|
||||
<string>UIInterfaceOrientationPortrait</string>
|
||||
</array>
|
||||
<key>CFBundleURLTypes</key>
|
||||
<array>
|
||||
<dict>
|
||||
<key>CFBundleURLName</key>
|
||||
<string>com.google.firebase</string>
|
||||
<key>CFBundleTypeRole</key>
|
||||
<string>Editor</string>
|
||||
<key>CFBundleURLSchemes</key>
|
||||
<array>
|
||||
<string>firebase-game-loop</string>
|
||||
</array>
|
||||
</dict>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
|
|
|
@ -14,6 +14,11 @@
|
|||
|
||||
#import <UIKit/UIKit.h>
|
||||
|
||||
@interface ViewController : UIViewController
|
||||
typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) {
|
||||
MediaPipeDemoSourceBackCamera,
|
||||
MediaPipeDemoSourceVideo
|
||||
};
|
||||
|
||||
@interface ViewController : UIViewController
|
||||
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode;
|
||||
@end
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#import "mediapipe/objc/MPPGraph.h"
|
||||
#import "mediapipe/objc/MPPCameraInputSource.h"
|
||||
#import "mediapipe/objc/MPPLayerRenderer.h"
|
||||
#import "mediapipe/objc/MPPPlayerInputSource.h"
|
||||
|
||||
static NSString* const kGraphName = @"mobile_gpu";
|
||||
|
||||
|
@ -35,6 +36,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
|||
@implementation ViewController {
|
||||
/// Handles camera access via AVCaptureSession library.
|
||||
MPPCameraInputSource* _cameraSource;
|
||||
MPPPlayerInputSource* _videoSource;
|
||||
MediaPipeDemoSourceMode _sourceMode;
|
||||
|
||||
/// Inform the user when camera is unavailable.
|
||||
IBOutlet UILabel* _noCameraLabel;
|
||||
|
@ -47,6 +50,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
|||
dispatch_queue_t _videoQueue;
|
||||
}
|
||||
|
||||
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode {
|
||||
_sourceMode = mode;
|
||||
}
|
||||
|
||||
#pragma mark - Cleanup methods
|
||||
|
||||
- (void)dealloc {
|
||||
|
@ -97,13 +104,6 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
|||
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0);
|
||||
_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute);
|
||||
|
||||
_cameraSource = [[MPPCameraInputSource alloc] init];
|
||||
[_cameraSource setDelegate:self queue:_videoQueue];
|
||||
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
|
||||
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
|
||||
// The frame's native format is rotated with respect to the portrait orientation.
|
||||
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
|
||||
|
||||
self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
|
||||
self.mediapipeGraph.delegate = self;
|
||||
// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
|
||||
|
@ -119,27 +119,43 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
|||
- (void)viewWillAppear:(BOOL)animated {
|
||||
[super viewWillAppear:animated];
|
||||
|
||||
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
|
||||
if (granted) {
|
||||
[self startGraphAndCamera];
|
||||
dispatch_async(dispatch_get_main_queue(), ^{
|
||||
_noCameraLabel.hidden = YES;
|
||||
});
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
- (void)startGraphAndCamera {
|
||||
// Start running self.mediapipeGraph.
|
||||
NSError* error;
|
||||
if (![self.mediapipeGraph startWithError:&error]) {
|
||||
NSLog(@"Failed to start graph: %@", error);
|
||||
}
|
||||
|
||||
// Start fetching frames from the camera.
|
||||
switch (_sourceMode) {
|
||||
case MediaPipeDemoSourceVideo: {
|
||||
AVAsset* video =
|
||||
[AVAsset assetWithURL:[[NSBundle mainBundle] URLForResource:@"object_detection"
|
||||
withExtension:@"mov"]];
|
||||
_videoSource = [[MPPPlayerInputSource alloc] initWithAVAsset:video];
|
||||
[_videoSource setDelegate:self queue:_videoQueue];
|
||||
dispatch_async(_videoQueue, ^{
|
||||
[_videoSource start];
|
||||
});
|
||||
break;
|
||||
}
|
||||
case MediaPipeDemoSourceBackCamera:
|
||||
_cameraSource = [[MPPCameraInputSource alloc] init];
|
||||
[_cameraSource setDelegate:self queue:_videoQueue];
|
||||
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
|
||||
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
|
||||
// The frame's native format is rotated with respect to the portrait orientation.
|
||||
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
|
||||
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
|
||||
if (granted) {
|
||||
dispatch_async(_videoQueue, ^{
|
||||
[_cameraSource start];
|
||||
});
|
||||
dispatch_async(dispatch_get_main_queue(), ^{
|
||||
_noCameraLabel.hidden = YES;
|
||||
});
|
||||
}
|
||||
}];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma mark - MPPGraphDelegate methods
|
||||
|
@ -164,7 +180,7 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
|
|||
- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
|
||||
timestamp:(CMTime)timestamp
|
||||
fromSource:(MPPInputSource*)source {
|
||||
if (source != _cameraSource) {
|
||||
if (source != _cameraSource && source != _videoSource) {
|
||||
NSLog(@"Unknown source: %@", source);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ exports_files([
|
|||
mediapipe_proto_library(
|
||||
name = "calculator_proto",
|
||||
srcs = ["calculator.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:mediapipe_options_proto",
|
||||
|
@ -68,7 +68,7 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "calculator_profile_proto",
|
||||
srcs = ["calculator_profile.proto"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -830,6 +830,8 @@ cc_library(
|
|||
":port",
|
||||
":timestamp",
|
||||
":type_map",
|
||||
"//mediapipe/framework/deps:no_destructor",
|
||||
"//mediapipe/framework/deps:registration",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
|
@ -1524,6 +1526,21 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "packet_registration_test",
|
||||
size = "small",
|
||||
srcs = ["packet_registration_test.cc"],
|
||||
deps = [
|
||||
":calculator_framework",
|
||||
":packet",
|
||||
":packet_test_cc_proto",
|
||||
":type_map",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "packet_generator_test",
|
||||
size = "small",
|
||||
|
|
|
@ -115,6 +115,9 @@ class CalculatorContract {
|
|||
// When true, Process is called for every new timestamp bound, with or without
|
||||
// new packets. A call to Process with only an input timestamp bound is
|
||||
// normally used to compute a new output timestamp bound.
|
||||
// NOTE: Also, when true, Process is called when input streams become done,
|
||||
// which means, Process needs to handle input streams in "done" state.
|
||||
// (Usually, by closing calculators' outputs where and when appropriate.)
|
||||
void SetProcessTimestampBounds(bool process_timestamps) {
|
||||
process_timestamps_ = process_timestamps;
|
||||
}
|
||||
|
|
|
@ -91,6 +91,9 @@ typedef ::mediapipe::StatusOr<OutputStreamPoller> StatusOrPoller;
|
|||
// {{"video_id", mediapipe::MakePacket<std::string>("Ex-uGhDzue4")}}));
|
||||
// // See mediapipe/framework/graph_runner.h for an interface
|
||||
// // to insert and extract packets from a graph as it runs.
|
||||
// // Once it is done using the graph, close its streams and wait till done.
|
||||
// MP_RETURN_IF_ERROR(graph->CloseAllInputStreams());
|
||||
// MP_RETURN_IF_ERROR(graph->WaitUntilDone());
|
||||
class CalculatorGraph {
|
||||
public:
|
||||
// Defines possible modes for adding a packet to a graph input stream.
|
||||
|
@ -157,8 +160,9 @@ class CalculatorGraph {
|
|||
std::function<::mediapipe::Status(const Packet&)> packet_callback);
|
||||
|
||||
// Adds an OutputStreamPoller for a stream. This provides a synchronous,
|
||||
// polling API for accessing a stream's output. For asynchronous output, use
|
||||
// ObserveOutputStream. See also the helpers in tool/sink.h.
|
||||
// polling API for accessing a stream's output. Should only be called before
|
||||
// Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
|
||||
// also the helpers in tool/sink.h.
|
||||
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name);
|
||||
|
||||
// Gets output side packet by name after the graph is done. However, base
|
||||
|
@ -300,6 +304,13 @@ class CalculatorGraph {
|
|||
void RecordError(const ::mediapipe::Status& error)
|
||||
ABSL_LOCKS_EXCLUDED(error_mutex_);
|
||||
|
||||
// Combines errors into a status. Returns true if the vector of errors is
|
||||
// non-empty.
|
||||
bool GetCombinedErrors(const std::string& error_prefix,
|
||||
::mediapipe::Status* error_status);
|
||||
// Convenience overload which specifies a default error prefix.
|
||||
bool GetCombinedErrors(::mediapipe::Status* error_status);
|
||||
|
||||
// Returns the maximum input stream queue size.
|
||||
int GetMaxInputStreamQueueSize();
|
||||
|
||||
|
@ -501,13 +512,6 @@ class CalculatorGraph {
|
|||
void CleanupAfterRun(::mediapipe::Status* status)
|
||||
ABSL_LOCKS_EXCLUDED(error_mutex_);
|
||||
|
||||
// Combines errors into a status. Returns true if the vector of errors is
|
||||
// non-empty.
|
||||
bool GetCombinedErrors(const std::string& error_prefix,
|
||||
::mediapipe::Status* error_status);
|
||||
// Convenience overload which specifies a default error prefix.
|
||||
bool GetCombinedErrors(::mediapipe::Status* error_status);
|
||||
|
||||
// Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one
|
||||
// is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN).
|
||||
// current_run_side_packets_ must be set before this function is called.
|
||||
|
|
|
@ -459,7 +459,8 @@ class Vector3
|
|||
int LargestAbsComponent() const {
|
||||
Vector3 temp = Abs();
|
||||
return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2
|
||||
: temp[1] > temp[2] ? 1 : 2;
|
||||
: temp[1] > temp[2] ? 1
|
||||
: 2;
|
||||
}
|
||||
|
||||
// return the index of the smallest, median ,largest component of the vector
|
||||
|
|
|
@ -155,7 +155,7 @@ class InputStreamHandler {
|
|||
// max number of invocations that are allowed to be scheduled is reached.
|
||||
// Returns true if at least one invocation has been scheduled.
|
||||
// The latest minimum timestamp bound of the input streams is returned in
|
||||
// *input_bound iff the latest readiness of the node is kNotReady when the
|
||||
// *input_bound if the latest readiness of the node is kNotReady when the
|
||||
// function returns. During batching, this value will be equal to the
|
||||
// timestamp of the first set of inputs in the batch. In other cases,
|
||||
// Timestamp::Unset() is returned.
|
||||
|
|
|
@ -66,6 +66,20 @@ class LegacyCalculatorSupport {
|
|||
};
|
||||
};
|
||||
|
||||
// We only declare this variable for two specializations of the template because
|
||||
// it is only meant to be used for these two types.
|
||||
// Note that, since these variables are members of specific template
|
||||
// _specializations_, they are not themselves templates, and therefore their
|
||||
// definitions must be in the .cc file. However, a declaration still needs to be
|
||||
// included in the header, or some compilers will assume they have no
|
||||
// definition.
|
||||
template <>
|
||||
thread_local CalculatorContext*
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
|
||||
template <>
|
||||
thread_local CalculatorContract*
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_
|
||||
|
|
|
@ -51,6 +51,18 @@ const HolderBase* GetHolder(const Packet& packet) {
|
|||
return packet.holder_.get();
|
||||
}
|
||||
|
||||
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
|
||||
const std::string& type_name, const std::string& serialized) {
|
||||
ASSIGN_OR_RETURN(
|
||||
auto message_holder,
|
||||
packet_internal::MessageHolderRegistry::CreateByName(type_name));
|
||||
auto* message =
|
||||
const_cast<proto_ns::MessageLite*>(message_holder->GetProtoMessageLite());
|
||||
RET_CHECK_NE(message, nullptr);
|
||||
RET_CHECK(message->ParseFromString(serialized));
|
||||
return packet_internal::Create(message_holder.release());
|
||||
}
|
||||
|
||||
} // namespace packet_internal
|
||||
|
||||
Packet Packet::At(class Timestamp timestamp) const& {
|
||||
|
|
|
@ -27,6 +27,8 @@
|
|||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/deps/no_destructor.h"
|
||||
#include "mediapipe/framework/deps/registration.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
@ -51,6 +53,8 @@ Packet Create(HolderBase* holder, Timestamp timestamp);
|
|||
Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp);
|
||||
const HolderBase* GetHolder(const Packet& packet);
|
||||
const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet);
|
||||
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
|
||||
const std::string& type_name, const std::string& serialized);
|
||||
} // namespace packet_internal
|
||||
|
||||
// A generic container class which can hold data of any type. The type of
|
||||
|
@ -355,112 +359,11 @@ class HolderBase {
|
|||
// Downcasts this to Holder<T>. Returns nullptr if deserialization
|
||||
// failed or if the requested type is not what is stored.
|
||||
template <typename T>
|
||||
inline Holder<T>* As(
|
||||
typename std::enable_if<
|
||||
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
|
||||
!std::is_base_of<proto_ns::Message, T>::value) ||
|
||||
(std::is_same<proto_ns::MessageLite, T>::value ||
|
||||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<Holder<T>*>(this);
|
||||
}
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// For proto Message/MessageLite subclasses.
|
||||
// When holder data is a concrete proto, the method downcasts this to
|
||||
// Holder<T> if the requested type is what is stored.
|
||||
// When holder data is a generic proto Message/MessageLite and a concrete
|
||||
// proto type T is requested, the method will downcast the HolderBase to
|
||||
// Holder<T> if the proto data is an instance of T.
|
||||
template <typename T>
|
||||
inline Holder<T>* As(
|
||||
typename std::enable_if<
|
||||
(std::is_base_of<proto_ns::MessageLite, T>::value ||
|
||||
std::is_base_of<proto_ns::Message, T>::value) &&
|
||||
(!std::is_same<proto_ns::MessageLite, T>::value &&
|
||||
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
|
||||
// Holder data is an instance of subclass type T.
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<Holder<T>*>(this);
|
||||
}
|
||||
|
||||
// Holder data is a generic proto Message/MessageLite and a subclass type T
|
||||
// is requested.
|
||||
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
|
||||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
|
||||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
|
||||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
|
||||
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
|
||||
// legally downcast to Holder<T>, even though that downcast works in
|
||||
// practice. Need to propose a better way to do the downcast.
|
||||
Holder<T>* holder = static_cast<Holder<T>*>(this);
|
||||
T tmp;
|
||||
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
|
||||
<< " vs requested proto type: " << tmp.GetTypeName();
|
||||
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
|
||||
return holder;
|
||||
}
|
||||
}
|
||||
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
Holder<T>* As();
|
||||
|
||||
// Same as non-const As() function.
|
||||
template <typename T>
|
||||
inline const Holder<T>* As(
|
||||
typename std::enable_if<
|
||||
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
|
||||
!std::is_base_of<proto_ns::Message, T>::value) ||
|
||||
(std::is_same<proto_ns::MessageLite, T>::value ||
|
||||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<const Holder<T>*>(this);
|
||||
}
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// For proto Message/MessageLite subclasses.
|
||||
// When holder data is a concrete proto, the method downcasts this to
|
||||
// Holder<T> if the requested type is what is stored.
|
||||
// When holder data is a generic proto Message/MessageLite and a concrete
|
||||
// proto type T is requested, the method will downcast the HolderBase to
|
||||
// Holder<T> if the proto data is an instance of T.
|
||||
template <typename T>
|
||||
inline const Holder<T>* As(
|
||||
typename std::enable_if<
|
||||
(std::is_base_of<proto_ns::MessageLite, T>::value ||
|
||||
std::is_base_of<proto_ns::Message, T>::value) &&
|
||||
(!std::is_same<proto_ns::MessageLite, T>::value &&
|
||||
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<const Holder<T>*>(this);
|
||||
}
|
||||
|
||||
// Holder data is a generic proto Message/MessageLite and a subclass type T
|
||||
// is requested.
|
||||
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
|
||||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
|
||||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
|
||||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
|
||||
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
|
||||
// legally downcast to Holder<T>, even though that downcast works in
|
||||
// practice. Need to propose a better way to do the downcast.
|
||||
Holder<T>* holder = static_cast<const Holder<T>*>(this);
|
||||
T tmp;
|
||||
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
|
||||
<< " vs requested proto type: " << tmp.GetTypeName();
|
||||
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
|
||||
return holder;
|
||||
}
|
||||
}
|
||||
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
const Holder<T>* As() const;
|
||||
|
||||
// Returns the pointer to MessageLite type for the data in holder, if
|
||||
// underlying object is protocol buffer type, otherwise, nullptr is returned.
|
||||
|
@ -520,12 +423,68 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data,
|
|||
return result;
|
||||
}
|
||||
|
||||
// This registry is used to create Holders of the right concrete C++ type given
|
||||
// a proto type std::string (which is used as the registration key).
|
||||
class MessageHolderRegistry
|
||||
: public GlobalFactoryRegistry<std::unique_ptr<HolderBase>> {};
|
||||
|
||||
template <typename T>
|
||||
struct is_concrete_proto_t
|
||||
: public std::integral_constant<
|
||||
bool, std::is_base_of<proto_ns::MessageLite, T>{} &&
|
||||
!std::is_same<proto_ns::MessageLite, T>{} &&
|
||||
!std::is_same<proto_ns::Message, T>{}> {};
|
||||
|
||||
// Registers a message type. T must be a non-cv-qualified concrete proto type.
|
||||
template <typename T>
|
||||
struct MessageRegistrationImpl {
|
||||
static NoDestructor<mediapipe::RegistrationToken> registration;
|
||||
};
|
||||
|
||||
// Static members of template classes can be defined in the header.
|
||||
template <typename T>
|
||||
NoDestructor<mediapipe::RegistrationToken>
|
||||
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
|
||||
T{}.GetTypeName(), [] { return absl::make_unique<Holder<T>>(new T); }));
|
||||
|
||||
// For non-Message payloads, this does nothing.
|
||||
template <typename T, typename Enable = void>
|
||||
struct HolderSupport {
|
||||
static void EnsureStaticInit() {}
|
||||
};
|
||||
|
||||
// This template ensures that, for each concrete MessageLite subclass that is
|
||||
// stored in a Packet, we register a function that allows us to create a
|
||||
// Holder with the correct payload type from the proto's type name.
|
||||
template <typename T>
|
||||
struct HolderSupport<T,
|
||||
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
|
||||
// We must use std::remove_cv to ensure we don't try to register Foo twice if
|
||||
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
|
||||
// up to Holder?
|
||||
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
|
||||
// For the registration static member to be instantiated, it needs to be
|
||||
// referenced in a context that requires the definition to exist (see ISO/IEC
|
||||
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
|
||||
// We need two different call-sites to cover proto types for which packets
|
||||
// are only ever created (i.e. the protos are only produced by calculators)
|
||||
// and proto types for which packets are only ever consumed (i.e. the protos
|
||||
// are only consumed by calculators).
|
||||
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Holder : public HolderBase {
|
||||
public:
|
||||
explicit Holder(const T* ptr) : ptr_(ptr) { SetHolderTypeId<Holder>(); }
|
||||
explicit Holder(const T* ptr) : ptr_(ptr) {
|
||||
HolderSupport<T>::EnsureStaticInit();
|
||||
SetHolderTypeId<Holder>();
|
||||
}
|
||||
~Holder() override { delete_helper(); }
|
||||
const T& data() const { return *ptr_; }
|
||||
const T& data() const {
|
||||
HolderSupport<T>::EnsureStaticInit();
|
||||
return *ptr_;
|
||||
}
|
||||
size_t GetTypeId() const final { return tool::GetTypeHash<T>(); }
|
||||
// Releases the underlying data pointer and transfers the ownership to a
|
||||
// unique pointer.
|
||||
|
@ -622,6 +581,24 @@ class ForeignHolder : public Holder<T> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Holder<T>* HolderBase::As() {
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<Holder<T>*>(this);
|
||||
}
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const Holder<T>* HolderBase::As() const {
|
||||
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
|
||||
return static_cast<const Holder<T>*>(this);
|
||||
}
|
||||
// Does not hold a T.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace packet_internal
|
||||
|
||||
inline Packet::Packet(const Packet& packet)
|
||||
|
|
57
mediapipe/framework/packet_registration_test.cc
Normal file
57
mediapipe/framework/packet_registration_test.cc
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/packet_test.pb.h"
|
||||
#include "mediapipe/framework/port/core_proto_inc.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
namespace test_ns {
|
||||
|
||||
class TestSinkCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
|
||||
cc->Outputs().Tag("OUT").Set<int>();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
|
||||
cc->Outputs().Tag("OUT").AddPacket(
|
||||
MakePacket<int>(x).At(cc->InputTimestamp()));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(::mediapipe::test_ns::TestSinkCalculator);
|
||||
|
||||
} // namespace test_ns
|
||||
|
||||
TEST(PacketTest, InputTypeRegistration) {
|
||||
using testing::Contains;
|
||||
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
|
||||
"mediapipe.InputOnlyProto");
|
||||
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
|
||||
Contains("mediapipe.InputOnlyProto"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -174,54 +174,13 @@ TEST(PacketTest, ReturnGenericProtobufMessage) {
|
|||
.x(0));
|
||||
}
|
||||
|
||||
TEST(PacketTest, ReturnProtobufMessageSubType) {
|
||||
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
|
||||
new ::mediapipe::PacketTestProto);
|
||||
proto_ptr->add_x(123);
|
||||
Packet packet = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
|
||||
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
|
||||
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
|
||||
}
|
||||
|
||||
TEST(PacketTest, TryWrongProtobufMessageSubType) {
|
||||
// Packet of PacketTestProto.
|
||||
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
|
||||
new ::mediapipe::PacketTestProto);
|
||||
proto_ptr->add_x(123);
|
||||
Packet packet = Adopt(proto_ptr.release());
|
||||
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
|
||||
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
|
||||
|
||||
// Packet of proto_ns::Message.
|
||||
proto_ptr.reset(new ::mediapipe::PacketTestProto);
|
||||
proto_ptr->add_x(456);
|
||||
Packet packet2 = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
|
||||
EXPECT_FALSE(packet2.ValidateAsType<::mediapipe::SimpleProto>().ok());
|
||||
EXPECT_TRUE(packet2.ValidateAsType<::mediapipe::PacketTestProto>().ok());
|
||||
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
|
||||
}
|
||||
|
||||
TEST(PacketTest, ReturnProtobufMessageLiteSubType) {
|
||||
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
|
||||
new ::mediapipe::PacketTestProto);
|
||||
proto_ptr->add_x(123);
|
||||
Packet packet =
|
||||
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
|
||||
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
|
||||
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
|
||||
}
|
||||
|
||||
TEST(PacketTest, TryWrongProtobufMessageLiteSubType) {
|
||||
// Packet of PacketTestProto.
|
||||
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
|
||||
new ::mediapipe::PacketTestProto);
|
||||
// Packet of proto_ns::MessageLite.
|
||||
proto_ptr->add_x(456);
|
||||
Packet packet =
|
||||
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
|
||||
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
|
||||
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
|
||||
EXPECT_EQ(456, packet.Get<::mediapipe::PacketTestProto>().x(0));
|
||||
}
|
||||
|
||||
TEST(PacketTest, GetProtoBase) {
|
||||
|
@ -505,5 +464,26 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) {
|
|||
EXPECT_TRUE(packet2.IsEmpty());
|
||||
}
|
||||
|
||||
TEST(PacketTest, MessageHolderRegistration) {
|
||||
using testing::Contains;
|
||||
Packet packet = MakePacket<mediapipe::SimpleProto>();
|
||||
ASSERT_EQ(mediapipe::SimpleProto{}.GetTypeName(), "mediapipe.SimpleProto");
|
||||
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
|
||||
Contains("mediapipe.SimpleProto"));
|
||||
}
|
||||
|
||||
TEST(PacketTest, PacketFromSerializedProto) {
|
||||
mediapipe::SimpleProto original;
|
||||
original.add_value("foo");
|
||||
std::string serialized = original.SerializeAsString();
|
||||
|
||||
StatusOr<Packet> maybe_packet = packet_internal::PacketFromDynamicProto(
|
||||
"mediapipe.SimpleProto", serialized);
|
||||
MP_ASSERT_OK(maybe_packet);
|
||||
Packet packet = maybe_packet.ValueOrDie();
|
||||
MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>());
|
||||
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -39,3 +39,9 @@ message SerializationProxyProto {
|
|||
repeated float float_value = 2;
|
||||
repeated string string_value = 3;
|
||||
}
|
||||
|
||||
// This proto should be used only as an input to a calculator, to verify that
|
||||
// that case is covered.
|
||||
message InputOnlyProto {
|
||||
optional int32 x = 1;
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@
|
|||
// but may or may not still be able to run other OpenGL code.
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
|
||||
(defined(__APPLE__) || defined(__EMSCRIPTEN__) || \
|
||||
defined(MEDIAPIPE_DISABLE_GPU))
|
||||
defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER)
|
||||
#define MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
#endif
|
||||
|
||||
|
|
|
@ -143,8 +143,8 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
|
|||
{{MakePacket<std::string>("goodbye").At(start_timestamp_)}});
|
||||
|
||||
// Validate the GraphTrace data.
|
||||
EXPECT_THAT(GetTrace(),
|
||||
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||
EXPECT_THAT(
|
||||
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||
base_time: 1608911100000000
|
||||
base_timestamp: 1608911100000000
|
||||
stream_name: ""
|
||||
|
@ -163,7 +163,7 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
|
|||
stream_id: 1
|
||||
event_data: 1
|
||||
}
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
|
||||
}
|
||||
)")));
|
||||
}
|
||||
|
@ -205,18 +205,27 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
|||
LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
|
||||
{{MakePacket<std::string>("out").At(start_timestamp_)}});
|
||||
curr_time += absl::Microseconds(2000);
|
||||
ClearCalculatorContext("PCalculator_3");
|
||||
LogInputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
|
||||
|
||||
// Note: the packet data ID is based on the packet's payload address, which
|
||||
// means the same ID can be reused if data is allocated in the same location
|
||||
// as a previously expired packet (b/160212191). This means the generated
|
||||
// trace can change depending on the allocator. To keep results stable, we
|
||||
// must keep the packets used in this test alive until the end. Each
|
||||
// TestContextBuilder happens to keep a reference to all packets for the last
|
||||
// context, so for now we just create a separate TestContextBuilder instead of
|
||||
// clearing it. TODO: revise this test.
|
||||
SetUpCalculatorContext("PCalculator_3a", /*node_id=*/2, {"up_2"}, {"down_2"});
|
||||
LogInputPackets("PCalculator_3a", GraphTrace::PROCESS, curr_time,
|
||||
{MakePacket<std::string>("pup").At(start_timestamp_ + 5)});
|
||||
curr_time += absl::Microseconds(20000);
|
||||
LogOutputPackets(
|
||||
"PCalculator_3", GraphTrace::PROCESS, curr_time,
|
||||
"PCalculator_3a", GraphTrace::PROCESS, curr_time,
|
||||
{{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}});
|
||||
curr_time += absl::Microseconds(1000);
|
||||
|
||||
// Validate the GraphTrace data.
|
||||
EXPECT_THAT(GetTrace(),
|
||||
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||
EXPECT_THAT(
|
||||
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
|
||||
base_time: 1608911100000000
|
||||
base_timestamp: 1608911100000000
|
||||
stream_name: ""
|
||||
|
@ -238,9 +247,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
|||
stream_id: 1
|
||||
event_data: 1
|
||||
}
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 3 }
|
||||
output_trace { packet_timestamp: 5 stream_id: 3 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 3 event_data: 3 }
|
||||
output_trace { packet_timestamp: 5 stream_id: 3 event_data: 4 }
|
||||
}
|
||||
calculator_trace {
|
||||
node_id: 1
|
||||
|
@ -254,9 +263,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
|||
finish_time: 11000
|
||||
packet_timestamp: 0
|
||||
stream_id: 2
|
||||
event_data: 2
|
||||
event_data: 5
|
||||
}
|
||||
output_trace { packet_timestamp: 0 stream_id: 4 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 4 event_data: 6 }
|
||||
}
|
||||
calculator_trace {
|
||||
node_id: 2
|
||||
|
@ -270,9 +279,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
|||
finish_time: 16000
|
||||
packet_timestamp: 0
|
||||
stream_id: 3
|
||||
event_data: 3
|
||||
event_data: 7
|
||||
}
|
||||
output_trace { packet_timestamp: 0 stream_id: 5 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 5 event_data: 8 }
|
||||
}
|
||||
calculator_trace {
|
||||
node_id: 2
|
||||
|
@ -286,9 +295,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
|
|||
finish_time: 38000
|
||||
packet_timestamp: 5
|
||||
stream_id: 3
|
||||
event_data: 4
|
||||
event_data: 9
|
||||
}
|
||||
output_trace { packet_timestamp: 5 stream_id: 5 }
|
||||
output_trace { packet_timestamp: 5 stream_id: 5 event_data: 10 }
|
||||
}
|
||||
)")));
|
||||
|
||||
|
@ -1275,7 +1284,9 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
|
|||
GraphTrace trace_1;
|
||||
builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(),
|
||||
&trace_1);
|
||||
EXPECT_THAT(trace_1, EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
|
||||
EXPECT_THAT(
|
||||
trace_1,
|
||||
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
|
||||
R"(
|
||||
base_time: 1100
|
||||
base_timestamp: 1000
|
||||
|
@ -1294,7 +1305,7 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
|
|||
stream_id: 1
|
||||
event_data: 0
|
||||
}
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 }
|
||||
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 0 }
|
||||
thread_id: 0
|
||||
}
|
||||
calculator_trace {
|
||||
|
|
|
@ -330,13 +330,12 @@ class TraceBuilder::Impl {
|
|||
if (trace_event_registry_[event->event_type].is_stream_event()) {
|
||||
auto stream_trace = event->is_finish ? result->add_output_trace()
|
||||
: result->add_input_trace();
|
||||
if (event->is_finish) {
|
||||
// Log only the packet id for each output event.
|
||||
stream_trace->set_stream_id(stream_id_map_[event->stream_id]);
|
||||
stream_trace->set_packet_timestamp(LogTimestamp(event->packet_ts));
|
||||
} else {
|
||||
// Log the full stream trace for each input event.
|
||||
BuildStreamTrace(*event, stream_trace);
|
||||
if (!event->is_finish) {
|
||||
// Note: is_finish is true for output events, false for input events.
|
||||
// For input events, we log some additional timing information. The
|
||||
// finish_time is the start_time of this Process call, the start_time
|
||||
// is the finish_time of the Process call that output the packet.
|
||||
stream_trace->set_finish_time(LogTime(event->event_time));
|
||||
const TraceEvent* output_event = FindOutputEvent(*event);
|
||||
if (output_event) {
|
||||
|
|
|
@ -116,10 +116,19 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
|
|||
CHECK_EQ(stream_ts, Timestamp::Done());
|
||||
if (ProcessTimestampBounds()) {
|
||||
// With kReadyForClose, the timestamp-bound Done is returned.
|
||||
// This bound is processed using the preceding input-timestamp.
|
||||
// TODO: Make all InputStreamHandlers process Done() like this.
|
||||
ready_timestamps_[i] = stream_ts.PreviousAllowedInStream();
|
||||
input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
|
||||
static const Timestamp kDonePrecedingTimestamp =
|
||||
Timestamp::Done().PreviousAllowedInStream();
|
||||
if (prev_ts < kDonePrecedingTimestamp) {
|
||||
// When kReadyForClose is received for the first time for a sync set,
|
||||
// it is processed using the timestamp preceding Done() to indicate
|
||||
// input stream is done, but still needs to be processed.
|
||||
min_bound = std::min(min_bound, kDonePrecedingTimestamp);
|
||||
input_timestamp = std::min(input_timestamp, kDonePrecedingTimestamp);
|
||||
ready_timestamps_[i] = kDonePrecedingTimestamp;
|
||||
} else {
|
||||
ready_timestamps_[i] = Timestamp::Done();
|
||||
}
|
||||
} else if (prev_ts < Timestamp::Done()) {
|
||||
stream_became_done = true;
|
||||
ready_timestamps_[i] = Timestamp::Done();
|
||||
|
|
|
@ -133,6 +133,11 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
|
|||
}
|
||||
}
|
||||
|
||||
const InputStream& Input(const CollectionItemId& id) {
|
||||
CHECK(cc_);
|
||||
return cc_->Inputs().Get(id);
|
||||
}
|
||||
|
||||
PacketType packet_type_;
|
||||
std::function<void()> headers_ready_callback_;
|
||||
std::function<void()> notification_callback_;
|
||||
|
@ -262,6 +267,344 @@ TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) {
|
|||
EXPECT_TRUE(errors_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ImmediateInputStreamHandlerTest, ProcessTimestampBounds) {
|
||||
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||
|
||||
Timestamp min_stream_timestamp;
|
||||
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::PreStream());
|
||||
|
||||
const auto& input_a_id = name_to_id_["input_a"];
|
||||
const auto& input_b_id = name_to_id_["input_b"];
|
||||
const auto& input_c_id = name_to_id_["input_c"];
|
||||
|
||||
std::list<Packet> packets;
|
||||
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
|
||||
input_stream_handler_->AddPackets(input_b_id, packets);
|
||||
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
|
||||
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
EXPECT_TRUE(
|
||||
input_stream_handler_->GetInputStreamManager(input_b_id)->IsEmpty());
|
||||
|
||||
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
|
||||
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
EXPECT_TRUE(errors_.empty());
|
||||
|
||||
// Schedule invocation for Close.
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
EXPECT_TRUE(errors_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ImmediateInputStreamHandlerTest,
|
||||
ProcessTimestampBoundsNoOpScheduleInvocations) {
|
||||
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||
|
||||
const auto& input_a_id = name_to_id_["input_a"];
|
||||
const auto& input_b_id = name_to_id_["input_b"];
|
||||
const auto& input_c_id = name_to_id_["input_c"];
|
||||
|
||||
Timestamp min_stream_timestamp;
|
||||
std::list<Packet> packets;
|
||||
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
|
||||
input_stream_handler_->AddPackets(input_b_id, packets);
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
|
||||
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
|
||||
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
EXPECT_TRUE(errors_.empty());
|
||||
|
||||
// Try to schedule invocations several times again. Considering nothing
|
||||
// changed since last invocation nothing should be scheduled.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp(2));
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
}
|
||||
|
||||
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
EXPECT_TRUE(errors_.empty());
|
||||
|
||||
// Schedule invocation for Close.
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
EXPECT_TRUE(errors_.empty());
|
||||
|
||||
// Try to schedule invocations several times again. Considering nothing
|
||||
// changed since last invocation nothing should be scheduled.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
}
|
||||
}
|
||||
|
||||
// Due to some temporary changes in ImmediateInputStreamHandler some packets
|
||||
// - were queued but never released
|
||||
// - were released in incorrect order
|
||||
// As other test cases were passing, this test case is designed to ensure that.
|
||||
TEST_F(ImmediateInputStreamHandlerTest, VerifyPacketsReleaseOrder) {
|
||||
input_stream_handler_->SetProcessTimestampBounds(true);
|
||||
|
||||
const auto& input_a_id = name_to_id_["input_a"];
|
||||
const auto& input_b_id = name_to_id_["input_b"];
|
||||
const auto& input_c_id = name_to_id_["input_c"];
|
||||
|
||||
Packet packet_a = Adopt(new std::string("packet a"));
|
||||
Packet packet_b = Adopt(new std::string("packet b"));
|
||||
Packet packet_c = Adopt(new std::string("packet c"));
|
||||
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(1))});
|
||||
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(2))});
|
||||
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(3))});
|
||||
|
||||
Timestamp min_stream_timestamp;
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
|
||||
ASSERT_FALSE(Input(input_a_id).IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(1));
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(5))});
|
||||
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(5))});
|
||||
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(5))});
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(2));
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
|
||||
ASSERT_FALSE(Input(input_b_id).IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(2));
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(3));
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(4));
|
||||
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(3));
|
||||
|
||||
// FinalizeInputSet() is a no-op.
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(5));
|
||||
ASSERT_FALSE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(5));
|
||||
ASSERT_FALSE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(5));
|
||||
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(5));
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
|
||||
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
|
||||
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
|
||||
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
// Schedule invocation for Close.
|
||||
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
|
||||
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
|
||||
&cc_->Inputs());
|
||||
input_stream_handler_->ClearCurrentInputs(cc_);
|
||||
|
||||
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
|
||||
/*max_allowance=*/1, &min_stream_timestamp));
|
||||
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
|
||||
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
|
||||
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
|
||||
}
|
||||
|
||||
// This test simulates how CalculatorNode::ProcessNode() uses an input
|
||||
// stream handler and the associated input streams.
|
||||
TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) {
|
||||
|
|
|
@ -641,4 +641,61 @@ class DummyTestCalculator : public CalculatorBase {
|
|||
};
REGISTER_CALCULATOR(DummyTestCalculator);

// A Calculator that passes the input value to the output after sleeping for
// a set number of microseconds.
class PassThroughWithSleepCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
    cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
    if (sleep_micros_ < 0) {
      return ::mediapipe::InternalError("SLEEP_MICROS should be >= 0");
    }
    clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Process(CalculatorContext* cc) final {
    clock_->Sleep(absl::Microseconds(sleep_micros_));
    int value = cc->Inputs().Index(0).Value().Get<int>();
    cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

 private:
  int sleep_micros_ = 0;
  std::shared_ptr<Clock> clock_;
};
REGISTER_CALCULATOR(PassThroughWithSleepCalculator);

// A Calculator that multiplies two input values.
class MultiplyIntCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
    // cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    RET_CHECK(cc->Outputs().HasTag("OUT"));
    cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }
  ::mediapipe::Status Process(CalculatorContext* cc) final {
    int x = cc->Inputs().Index(0).Value().Get<int>();
    int y = cc->Inputs().Index(1).Value().Get<int>();
    cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(MultiplyIntCalculator);

}  // namespace mediapipe
|
||||
|
|
|
@ -101,6 +101,13 @@ std::string ParseNameFromStream(const std::string& stream) {
  return name;
}

std::pair<std::string, int> ParseTagIndex(const std::string& tag_index) {
  std::string tag;
  int index;
  MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
  return {tag, index};
}

std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) {
  std::string tag, name;
  int index;

@ -76,6 +76,9 @@ std::string CanonicalNodeName(const CalculatorGraphConfig& graph_config,
// Parses the name from a "tag:index:name".
std::string ParseNameFromStream(const std::string& stream);

// Parses the TagIndex from a "tag:index".
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index);

// Parses the TagIndex from a "tag:index:name".
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream);
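A quick usage sketch for the new helper (illustrative only; the tag string and variable name below are hypothetical):

    // "TAG:index" strings such as "VIDEO:0" are split into their components.
    std::pair<std::string, int> tag_index = ParseTagIndex("VIDEO:0");
    // tag_index.first == "VIDEO", tag_index.second == 0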
|
10  mediapipe/framework/tool/testdata/BUILD  vendored
|
@ -13,15 +13,15 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
||||
load(
|
||||
"//mediapipe/framework/tool:mediapipe_graph.bzl",
|
||||
"mediapipe_simple_subgraph",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
||||
filegroup(
|
||||
name = "test_graph",
|
||||
srcs = ["test.pbtxt"],
|
||||
|
@ -31,6 +31,8 @@ exports_files([
|
|||
"test.pbtxt",
|
||||
"dub_quad_test_subgraph.pbtxt",
|
||||
"nested_test_subgraph.pbtxt",
|
||||
"single_flow_container_test.pbtxt",
|
||||
"dual_flow_container_test.pbtxt",
|
||||
])
|
||||
|
||||
mediapipe_simple_subgraph(
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("//mediapipe/gpu:metal.bzl", "metal_library")
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
# Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can
|
||||
# interfere with desktop GL. b/73494271
|
||||
config_setting(
|
||||
|
|
|
@ -39,6 +39,7 @@ namespace mediapipe {
|
|||
// ROTATION: the counterclockwise rotation angle in degrees. This allows
// the user to specify different rotation angles for different frames. If this
// stream is provided, it will override the ROTATION input side packet.
// OUTPUT_DIMENSIONS: the output width and height in pixels.
// Additional output streams:
// TOP_BOTTOM_PADDING: If the FIT scale mode is used, this stream outputs the
// padding size of the input image in normalized value [0, 1] for top and bottom
|
||||
|
@ -103,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator);
|
|||
if (cc->Inputs().HasTag("ROTATION")) {
|
||||
cc->Inputs().Tag("ROTATION").Set<int>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
|
||||
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
|
||||
}
|
||||
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
|
||||
|
||||
if (cc->InputSidePackets().HasTag("OPTIONS")) {
|
||||
|
@ -181,6 +185,18 @@ REGISTER_CALCULATOR(GlScalerCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
|
||||
if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
|
||||
// OUTPUT_DIMENSIONS input stream is specified, but value is missing.
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
const auto& dimensions =
|
||||
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<DimensionsPacketType>();
|
||||
dst_width_ = dimensions[0];
|
||||
dst_height_ = dimensions[1];
|
||||
}
|
||||
|
||||
return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
|
||||
const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>();
|
||||
QuadRenderer* renderer = nullptr;
|
||||
|
|
|
@ -140,6 +140,9 @@ node {
|
|||
num_landmarks: 21
|
||||
input_image_width: 256
|
||||
input_image_height: 256
|
||||
# The additional scaling factor is used to account for the Z coordinate
|
||||
# distribution in the training data.
|
||||
normalize_z: 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -144,6 +144,9 @@ node {
|
|||
num_landmarks: 21
|
||||
input_image_width: 256
|
||||
input_image_height: 256
|
||||
# The additional scaling factor is used to account for the Z coordinate
|
||||
# distribution in the training data.
|
||||
normalize_z: 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ android_library(
|
|||
),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/glutil",
|
||||
"//third_party:androidx_appcompat",
|
||||
|
|
|
@ -14,17 +14,21 @@
|
|||
|
||||
package com.google.mediapipe.components;
|
||||
|
||||
import static java.lang.Math.max;
|
||||
|
||||
import android.graphics.SurfaceTexture;
|
||||
import android.opengl.GLES11Ext;
|
||||
import android.opengl.GLES20;
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.framework.AppTextureFrame;
|
||||
import com.google.mediapipe.framework.GlSyncToken;
|
||||
import com.google.mediapipe.glutil.ExternalTextureRenderer;
|
||||
import com.google.mediapipe.glutil.GlThread;
|
||||
import com.google.mediapipe.glutil.ShaderUtil;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Queue;
|
||||
import javax.microedition.khronos.egl.EGLContext;
|
||||
|
||||
/**
|
||||
|
@ -204,8 +208,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
private static final long NANOS_PER_MICRO = 1000; // Nanoseconds in one microsecond.
|
||||
private volatile SurfaceTexture surfaceTexture = null;
|
||||
private final List<TextureFrameConsumer> consumers;
|
||||
private List<AppTextureFrame> outputFrames = null;
|
||||
private int outputFrameIndex = -1;
|
||||
|
||||
private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>();
|
||||
private int framesInUse = 0;
|
||||
private final int framesToKeep;
|
||||
|
||||
private ExternalTextureRenderer renderer = null;
|
||||
private long nextFrameTimestampOffset = 0;
|
||||
private long timestampOffsetNanos = 0;
|
||||
|
@ -215,10 +222,27 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
protected int destinationWidth = 0;
|
||||
protected int destinationHeight = 0;
|
||||
|
||||
private class PoolTextureFrame extends AppTextureFrame {
|
||||
public PoolTextureFrame(int textureName, int width, int height) {
|
||||
super(textureName, width, height);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(GlSyncToken syncToken) {
|
||||
super.release(syncToken);
|
||||
poolFrameReleased(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release() {
|
||||
super.release();
|
||||
poolFrameReleased(this);
|
||||
}
|
||||
}
|
||||
|
||||
public RenderThread(EGLContext parentContext, int numBuffers) {
|
||||
super(parentContext);
|
||||
outputFrames = new ArrayList<>();
|
||||
outputFrames.addAll(Collections.nCopies(numBuffers, null));
|
||||
framesToKeep = numBuffers;
|
||||
renderer = new ExternalTextureRenderer();
|
||||
consumers = new ArrayList<>();
|
||||
}
|
||||
|
@ -283,8 +307,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
@Override
|
||||
public void releaseGl() {
|
||||
setSurfaceTexture(null, 0, 0);
|
||||
for (int i = 0; i < outputFrames.size(); ++i) {
|
||||
teardownDestination(i);
|
||||
while (!framesAvailable.isEmpty()) {
|
||||
teardownFrame(framesAvailable.remove());
|
||||
}
|
||||
renderer.release();
|
||||
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
|
||||
|
@ -337,16 +361,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
}
|
||||
}
|
||||
|
||||
private void teardownDestination(int index) {
|
||||
if (outputFrames.get(index) != null) {
|
||||
waitUntilReleased(outputFrames.get(index));
|
||||
GLES20.glDeleteTextures(1, new int[] {outputFrames.get(index).getTextureName()}, 0);
|
||||
outputFrames.set(index, null);
|
||||
}
|
||||
private static void teardownFrame(AppTextureFrame frame) {
|
||||
GLES20.glDeleteTextures(1, new int[] {frame.getTextureName()}, 0);
|
||||
}
|
||||
|
||||
private void setupDestination(int index) {
|
||||
teardownDestination(index);
|
||||
private PoolTextureFrame createFrame() {
|
||||
int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight);
|
||||
Log.d(
|
||||
TAG,
|
||||
|
@ -354,11 +373,9 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
"Created output texture: %d width: %d height: %d",
|
||||
destinationTextureId, destinationWidth, destinationHeight));
|
||||
bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight);
|
||||
outputFrames.set(
|
||||
index, new AppTextureFrame(destinationTextureId, destinationWidth, destinationHeight));
|
||||
return new PoolTextureFrame(destinationTextureId, destinationWidth, destinationHeight);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
 * Gets the next available frame, or creates a new one if the next frame is not
 * initialized or cannot be used with the current surface texture.
|
||||
|
@ -371,20 +388,38 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
* NOTE: must be invoked on GL thread
|
||||
*/
|
||||
private AppTextureFrame nextOutputFrame() {
|
||||
outputFrameIndex = (outputFrameIndex + 1) % outputFrames.size();
|
||||
AppTextureFrame outputFrame = outputFrames.get(outputFrameIndex);
|
||||
// Check if the size has changed.
|
||||
if (outputFrame == null
|
||||
|| outputFrame.getWidth() != destinationWidth
|
||||
|| outputFrame.getHeight() != destinationHeight) {
|
||||
// setupDestination will wait for the frame to be released before reallocating it.
|
||||
setupDestination(outputFrameIndex);
|
||||
outputFrame = outputFrames.get(outputFrameIndex);
|
||||
PoolTextureFrame outputFrame;
|
||||
synchronized (this) {
|
||||
outputFrame = framesAvailable.poll();
|
||||
framesInUse++;
|
||||
}
|
||||
if (outputFrame == null) {
|
||||
outputFrame = createFrame();
|
||||
} else if (outputFrame.getWidth() != destinationWidth
|
||||
|| outputFrame.getHeight() != destinationHeight) {
|
||||
// Create anew if size has changed.
|
||||
// TODO: waiting for the consumer sync here may not be necessary.
|
||||
waitUntilReleased(outputFrame);
|
||||
teardownFrame(outputFrame);
|
||||
outputFrame = createFrame();
|
||||
} else {
|
||||
// Note: waitUntilReleased does two things: waits for the frame to be released by the CPU,
|
||||
// and syncs with the GPU sync token provided by the consumer. The first part is redundant
|
||||
// here (and completes immediately), but the second part is still needed.
|
||||
waitUntilReleased(outputFrame);
|
||||
}
|
||||
return outputFrame;
|
||||
}
|
||||
|
||||
protected synchronized void poolFrameReleased(PoolTextureFrame frame) {
|
||||
framesAvailable.offer(frame);
|
||||
framesInUse--;
|
||||
int keep = max(framesToKeep - framesInUse, 0);
|
||||
while (framesAvailable.size() > keep) {
|
||||
teardownFrame(framesAvailable.remove());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates output frame with current pixels of surface texture and corresponding timestamp.
|
||||
*
|
||||
|
@ -417,16 +452,22 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
|||
Log.v(
|
||||
TAG,
|
||||
String.format(
|
||||
"Waiting for tex: %d width: %d height: %d",
|
||||
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
|
||||
"Waiting for tex: %d width: %d height: %d timestamp: %d",
|
||||
frame.getTextureName(),
|
||||
frame.getWidth(),
|
||||
frame.getHeight(),
|
||||
frame.getTimestamp()));
|
||||
}
|
||||
frame.waitUntilReleased();
|
||||
if (Log.isLoggable(TAG, Log.VERBOSE)) {
|
||||
Log.v(
|
||||
TAG,
|
||||
String.format(
|
||||
"Finished waiting for tex: %d width: %d height: %d",
|
||||
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
|
||||
"Finished waiting for tex: %d width: %d height: %d timestamp: %d",
|
||||
frame.getTextureName(),
|
||||
frame.getWidth(),
|
||||
frame.getHeight(),
|
||||
frame.getTimestamp()));
|
||||
}
|
||||
} catch (InterruptedException ie) {
|
||||
// Someone interrupted our thread. This is not supposed to happen: we own
|
||||
|
|
|
@ -20,6 +20,7 @@ import android.media.AudioFormat;
|
|||
import android.os.Handler;
|
||||
import android.util.Log;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.framework.AndroidAssetUtil;
|
||||
import com.google.mediapipe.framework.AndroidPacketCreator;
|
||||
import com.google.mediapipe.framework.Graph;
|
||||
|
@ -32,10 +33,12 @@ import com.google.mediapipe.framework.SurfaceOutput;
|
|||
import com.google.mediapipe.framework.TextureFrame;
|
||||
import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
|
@ -106,6 +109,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
initializeGraphAndPacketCreator(context, graphName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*
|
||||
* @param graphConfig the proto object representation of the graph.
|
||||
*/
|
||||
public FrameProcessor(CalculatorGraphConfig graphConfig) {
|
||||
initializeGraphAndPacketCreator(graphConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a graph for processing data in real time.
|
||||
*
|
||||
|
@ -123,6 +135,17 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
packetCreator = new AndroidPacketCreator(mediapipeGraph);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a graph for processing data in real time.
|
||||
*
|
||||
* @param graphConfig the proto object representation of the graph.
|
||||
*/
|
||||
private void initializeGraphAndPacketCreator(CalculatorGraphConfig graphConfig) {
|
||||
mediapipeGraph = new Graph();
|
||||
mediapipeGraph.loadBinaryGraph(graphConfig);
|
||||
packetCreator = new AndroidPacketCreator(mediapipeGraph);
|
||||
}
|
||||
|
||||
/** Callback for errors occurring during processing in the graph. */
|
||||
public interface ErrorListener {
|
||||
void onError(RuntimeException error);
|
||||
|
@ -186,6 +209,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
currentConsumers = videoConsumers;
|
||||
}
|
||||
for (TextureFrameConsumer consumer : currentConsumers) {
|
||||
// Note: each consumer will release its TextureFrame, so each gets a separate object
|
||||
// (though they all reference the same data).
|
||||
TextureFrame frame = PacketGetter.getTextureFrame(packet);
|
||||
if (Log.isLoggable(TAG, Log.VERBOSE)) {
|
||||
Log.v(
|
||||
|
@ -373,9 +398,10 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
|
||||
/**
|
||||
* Returns true if the MediaPipe graph can accept one more input frame.
|
||||
*
|
||||
* @throws MediaPipeException for any error status.
|
||||
*/
|
||||
private boolean maybeAcceptNewFrame() {
|
||||
private boolean maybeAcceptNewFrame(long timestamp) {
|
||||
if (!started.getAndSet(true)) {
|
||||
startGraph();
|
||||
}
|
||||
|
@ -395,7 +421,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
|
||||
}
|
||||
|
||||
if (!maybeAcceptNewFrame()) {
|
||||
if (!maybeAcceptNewFrame(frame.getTimestamp())) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -451,7 +477,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
|
|||
public void onNewFrame(final Bitmap bitmap, long timestamp) {
|
||||
Packet packet = null;
|
||||
try {
|
||||
if (!maybeAcceptNewFrame()) {
|
||||
if (!maybeAcceptNewFrame(timestamp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@ package com.google.mediapipe.components;
|
|||
import android.Manifest;
|
||||
import android.app.Activity;
|
||||
import android.content.pm.PackageManager;
|
||||
import androidx.core.app.ActivityCompat;
|
||||
import android.util.Log;
|
||||
import androidx.core.app.ActivityCompat;
|
||||
import androidx.core.content.ContextCompat;
|
||||
|
||||
/** Manages camera permission request and handling. */
|
||||
|
|
|
@ -18,6 +18,10 @@ import com.google.mediapipe.framework.TextureFrame;
|
|||
|
||||
/** Lightweight abstraction for an object that can receive video frames. */
|
||||
public interface TextureFrameConsumer {
|
||||
/** Called when a new {@link TextureFrame} is available. */
|
||||
/**
|
||||
* Called when a new {@link TextureFrame} is available.
|
||||
*
|
||||
* Important: implementations of this method should call frame.release().
|
||||
**/
|
||||
public abstract void onNewFrame(TextureFrame frame);
|
||||
}
|
||||
|
|
|
@ -272,6 +272,10 @@ public final class PacketGetter {
|
|||
* <p>Note: in order for the application to be able to use the texture, its GL context must be
|
||||
* linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)}
|
||||
* with the native handle to the application's GL context as the second argument.
|
||||
*
|
||||
* <p>The returned GraphTextureFrame must be released by the caller. If this method is called
|
||||
* multiple times, each returned GraphTextureFrame is an independent reference to the underlying
|
||||
* texture data, and must be released individually.
|
||||
*/
|
||||
public static GraphTextureFrame getTextureFrame(final Packet packet) {
|
||||
return new GraphTextureFrame(
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
tricorder: {
|
||||
options: {
|
||||
builder: {
|
||||
config: "android_arm"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -33,22 +33,22 @@ struct GpuSharedData;
|
|||
|
||||
/// Provides the delegate with a new video frame.
|
||||
@optional
|
||||
- (void)mediapipeGraph:(MPPGraph*)graph
|
||||
- (void)mediapipeGraph:(MPPGraph *)graph
|
||||
didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer
|
||||
fromStream:(const std::string&)streamName;
|
||||
fromStream:(const std::string &)streamName;
|
||||
|
||||
/// Provides the delegate with a new video frame and time stamp.
|
||||
@optional
|
||||
- (void)mediapipeGraph:(MPPGraph*)graph
|
||||
- (void)mediapipeGraph:(MPPGraph *)graph
|
||||
didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer
|
||||
fromStream:(const std::string&)streamName
|
||||
timestamp:(const mediapipe::Timestamp&)timestamp;
|
||||
fromStream:(const std::string &)streamName
|
||||
timestamp:(const mediapipe::Timestamp &)timestamp;
|
||||
|
||||
/// Provides the delegate with a raw packet.
|
||||
@optional
|
||||
- (void)mediapipeGraph:(MPPGraph*)graph
|
||||
didOutputPacket:(const mediapipe::Packet&)packet
|
||||
fromStream:(const std::string&)streamName;
|
||||
- (void)mediapipeGraph:(MPPGraph *)graph
|
||||
didOutputPacket:(const mediapipe::Packet &)packet
|
||||
fromStream:(const std::string &)streamName;
|
||||
|
||||
@end
|
||||
|
||||
|
@ -100,34 +100,34 @@ typedef NS_ENUM(int, MPPPacketType) {
|
|||
|
||||
/// Copies the config and initializes the graph.
|
||||
/// @param config The configuration describing the graph.
|
||||
- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config
|
||||
- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig &)config
|
||||
NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (mediapipe::ProfilingContext*)getProfiler;
|
||||
- (mediapipe::ProfilingContext *)getProfiler;
|
||||
|
||||
/// Sets a stream header. If the header was already set, it is overwritten.
|
||||
/// @param packet The header.
|
||||
/// @param streamName The name of the stream.
|
||||
- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName;
|
||||
- (void)setHeaderPacket:(const mediapipe::Packet &)packet forStream:(const std::string &)streamName;
|
||||
|
||||
/// Sets a side packet. If it was already set, it is overwritten.
|
||||
/// Must be called before the graph is started.
|
||||
/// @param packet The packet to be associated with the input side packet.
|
||||
/// @param name The name of the input side packet.
|
||||
- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name;
|
||||
- (void)setSidePacket:(const mediapipe::Packet &)packet named:(const std::string &)name;
|
||||
|
||||
/// Sets a service packet. If it was already set, it is overwritten.
|
||||
/// Must be called before the graph is started.
|
||||
/// @param packet The packet to be associated with the service.
|
||||
/// @param service The service the packet should be associated with.
|
||||
- (void)setServicePacket:(mediapipe::Packet&)packet
|
||||
forService:(const mediapipe::GraphServiceBase&)service;
|
||||
- (void)setServicePacket:(mediapipe::Packet &)packet
|
||||
forService:(const mediapipe::GraphServiceBase &)service;
|
||||
|
||||
/// Adds input side packets from a map. Any inputs that were already set are
|
||||
/// left unchanged.
|
||||
/// Must be called before the graph is started.
|
||||
/// @param extraInputSidePackets The input side packets to be added.
|
||||
- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet>&)extraSidePackets;
|
||||
- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet> &)extraSidePackets;
|
||||
|
||||
// TODO: rename to addDelegateOutputStream:packetType:
|
||||
/// Add an output stream in the graph from which the delegate wants to receive
|
||||
|
@ -135,30 +135,30 @@ typedef NS_ENUM(int, MPPPacketType) {
|
|||
/// @param outputStreamName The name of the output stream from which
|
||||
/// the delegate will receive frames.
|
||||
/// @param packetType The type of packet provided by the output streams.
|
||||
- (void)addFrameOutputStream:(const std::string&)outputStreamName
|
||||
- (void)addFrameOutputStream:(const std::string &)outputStreamName
|
||||
outputPacketType:(MPPPacketType)packetType;
|
||||
|
||||
/// Starts running the graph.
|
||||
/// @return YES if successful.
|
||||
- (BOOL)startWithError:(NSError**)error;
|
||||
- (BOOL)startWithError:(NSError **)error;
|
||||
|
||||
/// Sends a generic packet into a graph input stream.
|
||||
/// The graph must have been started before calling this.
|
||||
/// Returns YES if the packet was successfully sent.
|
||||
- (BOOL)sendPacket:(const mediapipe::Packet&)packet
|
||||
intoStream:(const std::string&)streamName
|
||||
error:(NSError**)error;
|
||||
- (BOOL)sendPacket:(const mediapipe::Packet &)packet
|
||||
intoStream:(const std::string &)streamName
|
||||
error:(NSError **)error;
|
||||
|
||||
- (BOOL)movePacket:(mediapipe::Packet&&)packet
|
||||
intoStream:(const std::string&)streamName
|
||||
error:(NSError**)error;
|
||||
- (BOOL)movePacket:(mediapipe::Packet &&)packet
|
||||
intoStream:(const std::string &)streamName
|
||||
error:(NSError **)error;
|
||||
|
||||
/// Sets the maximum queue size for a stream. Experimental feature, currently
|
||||
/// only supported for graph input streams. Should be called before starting the
|
||||
/// graph.
|
||||
- (BOOL)setMaxQueueSize:(int)maxQueueSize
|
||||
forStream:(const std::string&)streamName
|
||||
error:(NSError**)error;
|
||||
forStream:(const std::string &)streamName
|
||||
error:(NSError **)error;
|
||||
|
||||
/// Creates a MediaPipe packet wrapping the given pixelBuffer;
|
||||
- (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)pixelBuffer
|
||||
|
@ -170,9 +170,9 @@ typedef NS_ENUM(int, MPPPacketType) {
|
|||
/// allows MediaPipe to overwrite the packet contents on successful sending for
|
||||
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
|
||||
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
|
||||
intoStream:(const std::string&)inputName
|
||||
intoStream:(const std::string &)inputName
|
||||
packetType:(MPPPacketType)packetType
|
||||
timestamp:(const mediapipe::Timestamp&)timestamp
|
||||
timestamp:(const mediapipe::Timestamp &)timestamp
|
||||
allowOverwrite:(BOOL)allowOverwrite;
|
||||
|
||||
/// Sends a pixel buffer into a graph input stream, using the specified packet
|
||||
|
@ -180,9 +180,23 @@ typedef NS_ENUM(int, MPPPacketType) {
|
|||
/// returns NO if maxFramesInFlight is exceeded. Returns YES if the packet was
|
||||
/// successfully sent.
|
||||
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer
|
||||
intoStream:(const std::string&)inputName
|
||||
intoStream:(const std::string &)inputName
|
||||
packetType:(MPPPacketType)packetType
|
||||
timestamp:(const mediapipe::Timestamp&)timestamp;
|
||||
timestamp:(const mediapipe::Timestamp &)timestamp;
|
||||
|
||||
/// Sends a pixel buffer into a graph input stream, using the specified packet
|
||||
/// type. The graph must have been started before calling this. Drops frames and
|
||||
/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES,
|
||||
/// allows MediaPipe to overwrite the packet contents on successful sending for
|
||||
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
|
||||
/// Sets error to a non-nil value if an error occurs in the graph when sending the
|
||||
/// packet.
|
||||
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
|
||||
intoStream:(const std::string &)inputName
|
||||
packetType:(MPPPacketType)packetType
|
||||
timestamp:(const mediapipe::Timestamp &)timestamp
|
||||
allowOverwrite:(BOOL)allowOverwrite
|
||||
error:(NSError **)error;
|
||||
|
||||
/// Sends a pixel buffer into a graph input stream, using the specified packet
|
||||
/// type. The graph must have been started before calling this. The timestamp is
|
||||
|
@ -190,32 +204,32 @@ typedef NS_ENUM(int, MPPPacketType) {
|
|||
/// frames and returns NO if maxFramesInFlight is exceeded. Returns YES if the
|
||||
/// packet was successfully sent.
|
||||
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer
|
||||
intoStream:(const std::string&)inputName
|
||||
intoStream:(const std::string &)inputName
|
||||
packetType:(MPPPacketType)packetType;
|
||||
|
||||
/// Cancels a graph run. You must still call waitUntilDoneWithError: after this.
|
||||
- (void)cancel;
|
||||
|
||||
/// Check if the graph contains this input stream
|
||||
- (BOOL)hasInputStream:(const std::string&)inputName;
|
||||
- (BOOL)hasInputStream:(const std::string &)inputName;
|
||||
|
||||
/// Closes an input stream.
|
||||
/// You must close all graph input streams before stopping the graph.
|
||||
/// @return YES if successful.
|
||||
- (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error;
|
||||
- (BOOL)closeInputStream:(const std::string &)inputName error:(NSError **)error;
|
||||
|
||||
/// Closes all graph input streams.
|
||||
/// @return YES if successful.
|
||||
- (BOOL)closeAllInputStreamsWithError:(NSError**)error;
|
||||
- (BOOL)closeAllInputStreamsWithError:(NSError **)error;
|
||||
|
||||
/// Stops running the graph.
|
||||
/// Call this before releasing this object. All input streams must have been
|
||||
/// closed. This call does not time out, so you should not call it from the main
|
||||
/// thread.
|
||||
/// @return YES if successful.
|
||||
- (BOOL)waitUntilDoneWithError:(NSError**)error;
|
||||
- (BOOL)waitUntilDoneWithError:(NSError **)error;
|
||||
|
||||
/// Waits for the graph to become idle.
|
||||
- (BOOL)waitUntilIdleWithError:(NSError**)error;
|
||||
- (BOOL)waitUntilIdleWithError:(NSError **)error;
|
||||
|
||||
@end
|
||||
|
|
|
@ -327,22 +327,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
|
|||
packetType:(MPPPacketType)packetType
|
||||
timestamp:(const mediapipe::Timestamp&)timestamp
|
||||
allowOverwrite:(BOOL)allowOverwrite {
|
||||
NSError* error;
|
||||
bool success = [self sendPixelBuffer:imageBuffer
|
||||
intoStream:inputName
|
||||
packetType:packetType
|
||||
timestamp:timestamp
|
||||
allowOverwrite:allowOverwrite
|
||||
error:&error];
|
||||
if (error) {
|
||||
_GTMDevLog(@"failed to send packet: %@", error);
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
|
||||
intoStream:(const std::string&)inputName
|
||||
packetType:(MPPPacketType)packetType
|
||||
timestamp:(const mediapipe::Timestamp&)timestamp
|
||||
allowOverwrite:(BOOL)allowOverwrite
|
||||
error:(NSError**)error {
|
||||
if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
|
||||
mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType];
|
||||
NSError* error;
|
||||
BOOL success;
|
||||
if (allowOverwrite) {
|
||||
packet = std::move(packet).At(timestamp);
|
||||
success = [self movePacket:std::move(packet)
|
||||
intoStream:inputName
|
||||
error:&error];
|
||||
success = [self movePacket:std::move(packet) intoStream:inputName error:error];
|
||||
} else {
|
||||
success = [self sendPacket:packet.At(timestamp)
|
||||
intoStream:inputName
|
||||
error:&error];
|
||||
success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
|
||||
}
|
||||
if (success) _framesInFlight++;
|
||||
else _GTMDevLog(@"failed to send packet: %@", error);
|
||||
return success;
|
||||
}
|
||||
|
||||
|
|
|
@ -423,6 +423,10 @@ tasks and tracking (or class) fields for tracking information.
|`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.|
|`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.|
|`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.|
|`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.|
|`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.|
|`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.|
|`region/3d_point/\*`| *special* |`add_bbox_3d_point` / `AddBBox3dPoint`|Operates on 3d_point/{x,y,z} with a single call.|
|`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.|
|`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.|
|`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. Distinguishes empty from unannotated frames.|
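The C++ accessors listed above store one list of 3D keypoints per frame, each keypoint as an <x, y, z> tuple. A minimal usage sketch, mirroring the round-trip unit test added in this commit (the `mediasequence` namespace and include paths are assumed from the existing media sequence utilities):

    #include <tuple>
    #include <vector>

    #include "mediapipe/util/sequence/media_sequence.h"
    #include "tensorflow/core/example/example.pb.h"

    tensorflow::SequenceExample sequence;
    // Two keypoints for one frame, each given as an (x, y, z) tuple.
    std::vector<std::tuple<float, float, float>> keypoints = {
        std::make_tuple(0.3f, 0.5f, 0.1f), std::make_tuple(0.4f, 0.7f, 0.2f)};
    mediapipe::mediasequence::AddBBox3dPoint(keypoints, &sequence);
    // Reads the keypoints stored for frame index 0 back as (x, y, z) tuples.
    auto restored = mediapipe::mediasequence::GetBBox3dPointAt(sequence, 0);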
@ -229,6 +229,18 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
|
|||
sequence);
|
||||
}
|
||||
}
|
||||
if (Get3dPointSize(prefix, *sequence) > 0) {
|
||||
std::string x_key = merge_prefix(prefix, kRegion3dPointXKey);
|
||||
auto* region_feature_list = MutableFeatureList(x_key, sequence);
|
||||
RET_CHECK_EQ(num_bboxes, region_feature_list->feature_size())
|
||||
<< "Expected number of BBox timestamps and boxes to match.";
|
||||
ClearBBoxNumRegions(prefix, sequence);
|
||||
for (int i = 0; i < num_bboxes; ++i) {
|
||||
AddBBoxNumRegions(
|
||||
prefix, region_feature_list->feature(i).float_list().value_size(),
|
||||
sequence);
|
||||
}
|
||||
}
|
||||
// Collect which timestamps currently match to which indices in timestamps.
|
||||
// skip empty timestamps.
|
||||
// Requires sorted indices.
|
||||
|
@ -453,6 +465,47 @@ void ClearPoint(const std::string& prefix,
|
|||
ClearBBoxPointX(prefix, sequence);
|
||||
}
|
||||
|
||||
int Get3dPointSize(const std::string& prefix,
|
||||
const tensorflow::SequenceExample& sequence) {
|
||||
return GetBBox3dPointXSize(prefix, sequence);
|
||||
}
|
||||
|
||||
std::vector<::std::tuple<float, float, float>> Get3dPointAt(
|
||||
const std::string& prefix, const tensorflow::SequenceExample& sequence,
|
||||
int index) {
|
||||
const auto& xs = GetBBox3dPointXAt(prefix, sequence, index);
|
||||
const auto& ys = GetBBox3dPointYAt(prefix, sequence, index);
|
||||
const auto& zs = GetBBox3dPointZAt(prefix, sequence, index);
|
||||
std::vector<::std::tuple<float, float, float>> points(ys.size());
|
||||
for (int i = 0; i < xs.size(); ++i) {
|
||||
points[i] = std::make_tuple(xs[i], ys[i], zs[i]);
|
||||
}
|
||||
return points;
|
||||
}
|
||||
|
||||
void Add3dPoint(const std::string& prefix,
|
||||
const std::vector<::std::tuple<float, float, float>>& points,
|
||||
tensorflow::SequenceExample* sequence) {
|
||||
::std::vector<float> xs;
|
||||
::std::vector<float> ys;
|
||||
::std::vector<float> zs;
|
||||
for (auto& point : points) {
|
||||
xs.push_back(std::get<0>(point));
|
||||
ys.push_back(std::get<1>(point));
|
||||
zs.push_back(std::get<2>(point));
|
||||
}
|
||||
AddBBox3dPointX(prefix, xs, sequence);
|
||||
AddBBox3dPointY(prefix, ys, sequence);
|
||||
AddBBox3dPointZ(prefix, zs, sequence);
|
||||
}
|
||||
|
||||
void Clear3dPoint(const std::string& prefix,
|
||||
tensorflow::SequenceExample* sequence) {
|
||||
ClearBBox3dPointX(prefix, sequence);
|
||||
ClearBBox3dPointY(prefix, sequence);
|
||||
ClearBBox3dPointZ(prefix, sequence);
|
||||
}
|
||||
|
||||
std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt(
|
||||
const std::string& prefix, const tensorflow::SequenceExample& sequence,
|
||||
int index) {
|
||||
|
|
|
@ -268,6 +268,10 @@ const char kRegionBBoxXMaxKey[] = "region/bbox/xmax";
|
|||
const char kRegionPointXKey[] = "region/point/x";
|
||||
const char kRegionPointYKey[] = "region/point/y";
|
||||
const char kRegionRadiusKey[] = "region/radius";
|
||||
// The 3d point can denote keypoints.
|
||||
const char kRegion3dPointXKey[] = "region/3d_point/x";
|
||||
const char kRegion3dPointYKey[] = "region/3d_point/y";
|
||||
const char kRegion3dPointZKey[] = "region/3d_point/z";
|
||||
// The number of regions at that timestep.
|
||||
const char kRegionNumRegionsKey[] = "region/num_regions";
|
||||
// Whether that timestep is annotated for bounding regions.
|
||||
|
@ -333,6 +337,18 @@ void AddPoint(const std::string& prefix,
|
|||
void ClearPoint(const std::string& prefix,
|
||||
tensorflow::SequenceExample* sequence);
|
||||
|
||||
// The input and output format is a tuple of <x, y, z> coordinates for each
// 3D point.
int Get3dPointSize(const std::string& prefix,
const tensorflow::SequenceExample& sequence);
|
||||
std::vector<std::tuple<float, float, float>> Get3dPointAt(
|
||||
const std::string& prefix, const tensorflow::SequenceExample& sequence,
|
||||
int index);
|
||||
void Add3dPoint(const std::string& prefix,
|
||||
const std::vector<std::tuple<float, float, float>>& points,
|
||||
tensorflow::SequenceExample* sequence);
|
||||
void Clear3dPoint(const std::string& prefix,
|
||||
tensorflow::SequenceExample* sequence);
|
||||
#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \
|
||||
inline int CONCAT_STR3(Get, identifier, \
|
||||
Size)(const tensorflow::SequenceExample& sequence) { \
|
||||
|
@ -388,6 +404,44 @@ void ClearPoint(const std::string& prefix,
|
|||
inline void CONCAT_STR3(Clear, identifier, Point)( \
|
||||
std::string name, tensorflow::SequenceExample * sequence) { \
|
||||
return ClearPoint(name, sequence); \
|
||||
} \
|
||||
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
|
||||
const tensorflow::SequenceExample& sequence) { \
|
||||
return Get3dPointSize(prefix, sequence); \
|
||||
} \
|
||||
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
|
||||
const std::string& name, const tensorflow::SequenceExample& sequence) { \
|
||||
return Get3dPointSize(name, sequence); \
|
||||
} \
|
||||
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
|
||||
Get, identifier, 3dPointAt)(const tensorflow::SequenceExample& sequence, \
|
||||
int index) { \
|
||||
return Get3dPointAt(prefix, sequence, index); \
|
||||
} \
|
||||
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
|
||||
Get, identifier, 3dPointAt)(const std::string& name, \
|
||||
const tensorflow::SequenceExample& sequence, \
|
||||
int index) { \
|
||||
return Get3dPointAt(name, sequence, index); \
|
||||
} \
|
||||
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
|
||||
const std::vector<std::tuple<float, float, float>>& points, \
|
||||
tensorflow::SequenceExample* sequence) { \
|
||||
return Add3dPoint(prefix, points, sequence); \
|
||||
} \
|
||||
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
|
||||
const std::string& name, \
|
||||
const std::vector<std::tuple<float, float, float>>& points, \
|
||||
tensorflow::SequenceExample* sequence) { \
|
||||
return Add3dPoint(name, points, sequence); \
|
||||
} \
|
||||
inline void CONCAT_STR3(Clear, identifier, \
|
||||
3dPoint)(tensorflow::SequenceExample * sequence) { \
|
||||
return Clear3dPoint(prefix, sequence); \
|
||||
} \
|
||||
inline void CONCAT_STR3(Clear, identifier, 3dPoint)( \
|
||||
std::string name, tensorflow::SequenceExample * sequence) { \
|
||||
return Clear3dPoint(name, sequence); \
|
||||
}
|
||||
|
||||
#define PREFIXED_BBOX(identifier, prefix) \
|
||||
|
@ -435,6 +489,12 @@ void ClearPoint(const std::string& prefix,
|
|||
kRegionPointYKey, prefix) \
|
||||
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, Radius), \
|
||||
kRegionRadiusKey, prefix) \
|
||||
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointX), \
|
||||
kRegion3dPointXKey, prefix) \
|
||||
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointY), \
|
||||
kRegion3dPointYKey, prefix) \
|
||||
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointZ), \
|
||||
kRegion3dPointZKey, prefix) \
|
||||
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST( \
|
||||
CONCAT_STR2(identifier, EmbeddingFloats), kRegionEmbeddingFloatKey, \
|
||||
prefix) \
|
||||
|
|
|
@ -262,6 +262,10 @@ REGION_BBOX_XMAX_KEY = "region/bbox/xmax"
|
|||
REGION_POINT_X_KEY = "region/point/x"
|
||||
REGION_POINT_Y_KEY = "region/point/y"
|
||||
REGION_RADIUS_KEY = "region/radius"
|
||||
# The 3D point can denote keypoints.
|
||||
REGION_3D_POINT_X_KEY = "region/3d_point/x"
|
||||
REGION_3D_POINT_Y_KEY = "region/3d_point/y"
|
||||
REGION_3D_POINT_Z_KEY = "region/3d_point/z"
|
||||
# The number of regions at that timestep.
|
||||
REGION_NUM_REGIONS_KEY = "region/num_regions"
|
||||
# Whether that timestep is annotated for regions.
|
||||
|
@ -365,6 +369,15 @@ def _create_region_with_prefix(name, prefix):
|
|||
prefix=prefix, module_dict=globals())
|
||||
msu.create_float_list_feature_list(name + "_point_y", REGION_POINT_Y_KEY,
|
||||
prefix=prefix, module_dict=globals())
|
||||
msu.create_float_list_feature_list(
|
||||
name + "_3d_point_x", REGION_3D_POINT_X_KEY,
|
||||
prefix=prefix, module_dict=globals())
|
||||
msu.create_float_list_feature_list(
|
||||
name + "_3d_point_y", REGION_3D_POINT_Y_KEY,
|
||||
prefix=prefix, module_dict=globals())
|
||||
msu.create_float_list_feature_list(
|
||||
name + "_3d_point_z", REGION_3D_POINT_Z_KEY,
|
||||
prefix=prefix, module_dict=globals())
|
||||
msu.create_bytes_list_context_feature(name + "_parts",
|
||||
REGION_PARTS_KEY,
|
||||
prefix=prefix, module_dict=globals())
|
||||
|
@ -406,6 +419,39 @@ def _create_region_with_prefix(name, prefix):
|
|||
clear_bbox_xmin(sequence_example, prefix=prefix)
|
||||
clear_bbox_ymax(sequence_example, prefix=prefix)
|
||||
clear_bbox_xmax(sequence_example, prefix=prefix)
|
||||
def get_prefixed_point_at(index, sequence_example, prefix):
|
||||
return np.stack((
|
||||
get_bbox_point_y_at(index, sequence_example, prefix=prefix),
|
||||
get_bbox_point_x_at(index, sequence_example, prefix=prefix)),
|
||||
1)
|
||||
def add_prefixed_point(values, sequence_example, prefix):
|
||||
add_bbox_point_y(values[:, 0], sequence_example, prefix=prefix)
|
||||
add_bbox_point_x(values[:, 1], sequence_example, prefix=prefix)
|
||||
def get_prefixed_point_size(sequence_example, prefix):
|
||||
return get_bbox_point_y_size(sequence_example, prefix=prefix)
|
||||
def has_prefixed_point(sequence_example, prefix):
|
||||
return has_bbox_point_y(sequence_example, prefix=prefix)
|
||||
def clear_prefixed_point(sequence_example, prefix):
|
||||
clear_bbox_point_y(sequence_example, prefix=prefix)
|
||||
clear_bbox_point_x(sequence_example, prefix=prefix)
|
||||
def get_prefixed_3d_point_at(index, sequence_example, prefix):
|
||||
return np.stack((
|
||||
get_bbox_3d_point_x_at(index, sequence_example, prefix=prefix),
|
||||
get_bbox_3d_point_y_at(index, sequence_example, prefix=prefix),
|
||||
get_bbox_3d_point_z_at(index, sequence_example, prefix=prefix)),
|
||||
1)
|
||||
def add_prefixed_3d_point(values, sequence_example, prefix):
|
||||
add_bbox_3d_point_x(values[:, 0], sequence_example, prefix=prefix)
|
||||
add_bbox_3d_point_y(values[:, 1], sequence_example, prefix=prefix)
|
||||
add_bbox_3d_point_z(values[:, 2], sequence_example, prefix=prefix)
|
||||
def get_prefixed_3d_point_size(sequence_example, prefix):
|
||||
return get_bbox_3d_point_x_size(sequence_example, prefix=prefix)
|
||||
def has_prefixed_3d_point(sequence_example, prefix):
|
||||
return has_bbox_3d_point_x(sequence_example, prefix=prefix)
|
||||
def clear_prefixed_3d_point(sequence_example, prefix):
|
||||
clear_bbox_3d_point_x(sequence_example, prefix=prefix)
|
||||
clear_bbox_3d_point_y(sequence_example, prefix=prefix)
|
||||
clear_bbox_3d_point_z(sequence_example, prefix=prefix)
|
||||
# pylint: enable=undefined-variable
|
||||
msu.add_functions_to_module({
|
||||
"get_" + name + "_at":
|
||||
|
@ -419,6 +465,30 @@ def _create_region_with_prefix(name, prefix):
|
|||
"clear_" + name:
|
||||
functools.partial(clear_prefixed_bbox, prefix=prefix),
|
||||
}, module_dict=globals())
|
||||
msu.add_functions_to_module({
|
||||
"get_" + name + "_point_at":
|
||||
functools.partial(get_prefixed_point_at, prefix=prefix),
|
||||
"add_" + name + "_point":
|
||||
functools.partial(add_prefixed_point, prefix=prefix),
|
||||
"get_" + name + "_point_size":
|
||||
functools.partial(get_prefixed_point_size, prefix=prefix),
|
||||
"has_" + name + "_point":
|
||||
functools.partial(has_prefixed_point, prefix=prefix),
|
||||
"clear_" + name + "_point":
|
||||
functools.partial(clear_prefixed_point, prefix=prefix),
|
||||
}, module_dict=globals())
|
||||
msu.add_functions_to_module({
|
||||
"get_" + name + "_3d_point_at":
|
||||
functools.partial(get_prefixed_3d_point_at, prefix=prefix),
|
||||
"add_" + name + "_3d_point":
|
||||
functools.partial(add_prefixed_3d_point, prefix=prefix),
|
||||
"get_" + name + "_3d_point_size":
|
||||
functools.partial(get_prefixed_3d_point_size, prefix=prefix),
|
||||
"has_" + name + "_3d_point":
|
||||
functools.partial(has_prefixed_3d_point, prefix=prefix),
|
||||
"clear_" + name + "_3d_point":
|
||||
functools.partial(clear_prefixed_3d_point, prefix=prefix),
|
||||
}, module_dict=globals())
|
||||
|
||||
|
||||
PREDICTED_PREFIX = "PREDICTED"
|
||||
|
|
|
@ -436,6 +436,21 @@ TEST(MediaSequenceTest, RoundTripBBoxPointPrefixed) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(MediaSequenceTest, RoundTripBBox3dPoint) {
|
||||
tensorflow::SequenceExample sequence;
|
||||
std::vector<std::vector<std::tuple<float, float, float>>> points = {
|
||||
{std::make_tuple(0.3, 0.5, 0.1), std::make_tuple(0.4, 0.7, 0.2)},
|
||||
{std::make_tuple(0.7, 0.5, 0.3), std::make_tuple(0.3, 0.4, 0.4)}};
|
||||
for (int i = 0; i < points.size(); ++i) {
|
||||
AddBBox3dPoint(points[i], &sequence);
|
||||
ASSERT_EQ(GetBBox3dPointSize(sequence), i + 1);
|
||||
const auto& sequence_points = GetBBox3dPointAt(sequence, i);
|
||||
for (int j = 0; j < sequence_points.size(); ++j) {
|
||||
EXPECT_EQ(sequence_points[j], points[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MediaSequenceTest, RoundTripRegionParts) {
|
||||
tensorflow::SequenceExample sequence;
|
||||
std::vector<std::string> parts = {"HEAD", "FEET"};
|
||||
|
|
|
@ -89,6 +89,9 @@ class MediaSequenceTest(tf.test.TestCase):
|
|||
ms.add_bbox_xmax((0.47, 0.49), example)
|
||||
ms.add_bbox_point_x((0.47, 0.49), example)
|
||||
ms.add_bbox_point_y((0.47, 0.49), example)
|
||||
ms.add_bbox_3d_point_x((0.47, 0.49), example)
|
||||
ms.add_bbox_3d_point_y((0.47, 0.49), example)
|
||||
ms.add_bbox_3d_point_z((0.47, 0.49), example)
|
||||
ms.add_predicted_bbox_ymin((0.47, 0.49), example)
|
||||
ms.add_predicted_bbox_xmin((0.47, 0.49), example)
|
||||
ms.add_predicted_bbox_ymax((0.47, 0.49), example)
|
||||
|
@ -133,6 +136,30 @@ class MediaSequenceTest(tf.test.TestCase):
|
|||
ms.clear_bbox(example)
|
||||
self.assertEqual(0, ms.get_bbox_size(example))
|
||||
|
||||
def test_point_round_trip(self):
|
||||
example = tf.train.SequenceExample()
|
||||
points = np.array([[0.1, 0.2],
|
||||
[0.5, 0.6]])
|
||||
ms.add_bbox_point(points, example)
|
||||
ms.add_bbox_point(points, example)
|
||||
self.assertEqual(2, ms.get_bbox_point_size(example))
|
||||
self.assertAllClose(points, ms.get_bbox_point_at(0, example))
|
||||
self.assertTrue(ms.has_bbox_point(example))
|
||||
ms.clear_bbox_point(example)
|
||||
self.assertEqual(0, ms.get_bbox_point_size(example))
|
||||
|
||||
def test_3d_point_round_trip(self):
|
||||
example = tf.train.SequenceExample()
|
||||
points = np.array([[0.1, 0.2, 0.3],
|
||||
[0.5, 0.6, 0.7]])
|
||||
ms.add_bbox_3d_point(points, example)
|
||||
ms.add_bbox_3d_point(points, example)
|
||||
self.assertEqual(2, ms.get_bbox_3d_point_size(example))
|
||||
self.assertAllClose(points, ms.get_bbox_3d_point_at(0, example))
|
||||
self.assertTrue(ms.has_bbox_3d_point(example))
|
||||
ms.clear_bbox_3d_point(example)
|
||||
self.assertEqual(0, ms.get_bbox_3d_point_size(example))
|
||||
|
||||
def test_predicted_bbox_round_trip(self):
|
||||
example = tf.train.SequenceExample()
|
||||
boxes = np.array([[0.1, 0.2, 0.3, 0.4],
|
||||
|
|
|
@ -19,6 +19,14 @@ package(default_visibility = [
|
|||
"//mediapipe:__subpackages__",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "config",
|
||||
hdrs = ["config.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_op_resolver",
|
||||
srcs = ["cpu_op_resolver.cc"],
|
||||
|
@ -69,6 +77,7 @@ cc_test(
|
|||
srcs = ["tensor_buffer_test.cc"],
|
||||
deps = [
|
||||
":tensor_buffer",
|
||||
":config",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
|
@ -99,6 +108,7 @@ cc_library(
|
|||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
||||
],
|
||||
"//mediapipe:android": [
|
||||
|
@@ -108,7 +118,9 @@ cc_library(
        "//mediapipe/framework/port:statusor",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
    ],
    }) + ["@org_tensorflow//tensorflow/lite/core/api"],
mediapipe/util/tflite/config.h (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#define MEDIAPIPE_UTIL_TFLITE_CONFIG_H_

#include "mediapipe/framework/calculator_framework.h"

// MediaPipe code should use the following defines to determine whether TFLite
// GPU support is available, and whether GL or Metal inference is available.

#ifdef MEDIAPIPE_DISABLE_GL_COMPUTE
#define MEDIAPIPE_TFLITE_GL_INFERENCE 0
#else
#define MEDIAPIPE_TFLITE_GL_INFERENCE 1
#endif  // MEDIAPIPE_DISABLE_GL_COMPUTE

#ifdef MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 1
#else
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 0
#endif  // MEDIAPIPE_IOS

#define MEDIAPIPE_TFLITE_GPU_SUPPORTED \
  ((MEDIAPIPE_TFLITE_GL_INFERENCE) || (MEDIAPIPE_TFLITE_METAL_INFERENCE))

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <Metal/Metal.h>
#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE

namespace mediapipe {

#if MEDIAPIPE_TFLITE_GL_INFERENCE
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
typedef id<MTLBuffer> GpuTensor;
#else
struct DummyGpuTensor {};
typedef DummyGpuTensor GpuTensor;  // Dummy define for less #ifdefs
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

}  // namespace mediapipe

#endif  // MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
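The header comment above says downstream code should branch on these macros instead of repeating platform #ifdefs. A minimal sketch of that pattern, assuming a hypothetical struct name (ExampleTensorSet) that is not part of this change:

#include <vector>

#include "mediapipe/util/tflite/config.h"

namespace mediapipe {

// Hypothetical container: always holds CPU tensors, and holds GPU tensors
// only when the build configuration selected by config.h provides a GPU
// backend.
struct ExampleTensorSet {
  std::vector<std::vector<float>> cpu_tensors;
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
  // GpuTensor resolves to tflite::gpu::gl::GlBuffer on GL builds and to
  // id<MTLBuffer> on iOS Metal builds.
  std::vector<GpuTensor> gpu_tensors;
#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
};

}  // namespace mediapipe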
@@ -130,8 +130,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto padding = params->padding;
  auto compute_out_size = [padding](int image_size, int filter_size,
                                    int stride) -> int {
-    return padding == kTfLitePaddingSame
-               ? (image_size + stride - 1) / stride
+    return padding == kTfLitePaddingSame ? (image_size + stride - 1) / stride
               : padding == kTfLitePaddingValid
                   ? (image_size - filter_size + stride) / stride
                   : 0;
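As a quick sanity check of the (unchanged) arithmetic, assume image_size = 10, filter_size = 3 and stride = 2: kTfLitePaddingSame yields (10 + 2 - 1) / 2 = 5 with integer division, kTfLitePaddingValid yields (10 - 3 + 2) / 2 = 4, and any other padding mode yields 0.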
@@ -2,6 +2,7 @@

#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/tflite/config.h"

namespace mediapipe {
@@ -12,7 +13,7 @@ TEST(Cpu, BasicTest) {
  EXPECT_FALSE(tb.UsesGpu());
}

-#if !defined(MEDIAPIPE_DISABLE_GPU)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
TEST(Gpu, BasicTest) {
  TensorBuffer tb;
  std::shared_ptr<tflite::gpu::gl::GlBuffer> tfg_tb =

@@ -20,7 +21,7 @@ TEST(Gpu, BasicTest) {
  tb = TensorBuffer(tfg_tb);
  EXPECT_TRUE(tb.UsesGpu());
}
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // !MEDIAPIPE_TFLITE_GL_INFERENCE

}  // namespace mediapipe
@@ -30,6 +30,13 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"

// This code should be enabled as soon as the TensorFlow version used by
// MediaPipe includes this module.
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"

namespace tflite {
namespace gpu {
namespace {
@@ -51,6 +58,19 @@ ObjectDef GetSSBOObjectDef(int channels) {
mediapipe::Status TFLiteGPURunner::InitializeWithModel(
    const tflite::FlatBufferModel& flatbuffer,
    const tflite::OpResolver& op_resolver) {
  // GraphFloat32 is created twice because, when the OpenCL and OpenGL backends
  // are initialized, different backend-specific graph transformations happen
  // in-place. As GraphFloat32 is not copyable by design, we keep two copies of
  // the graph until inference is built. This decision doesn't affect the
  // amount of runtime memory used, because both graph_gl_ and graph_cl_ are
  // deleted at the end of the initialization stage.
  graph_gl_ = std::make_unique<GraphFloat32>();
  graph_cl_ = std::make_unique<GraphFloat32>();
  MP_RETURN_IF_ERROR(
      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_gl_.get()));
  MP_RETURN_IF_ERROR(
      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_cl_.get()));

  for (const auto& input : graph_gl_->inputs()) {
    input_shapes_.push_back(input->tensor.shape);
  }
@@ -140,6 +160,19 @@ mediapipe::Status TFLiteGPURunner::InitializeOpenGL(

absl::Status TFLiteGPURunner::InitializeOpenCL(
    std::unique_ptr<InferenceBuilder>* builder) {
#ifdef __ANDROID__
  cl::InferenceEnvironmentOptions env_options;
  cl::InferenceEnvironmentProperties properties;
  cl::InferenceOptions cl_options;
  cl_options.priority1 = options_.priority1;
  cl_options.priority2 = options_.priority2;
  cl_options.priority3 = options_.priority3;
  cl_options.usage = options_.usage;
  MP_RETURN_IF_ERROR(
      cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
  MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
      cl_options, std::move(*graph_cl_), builder));
#endif
  return absl::OkStatus();
}
@@ -27,6 +27,10 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"

#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif

namespace tflite {
namespace gpu {
@@ -64,6 +68,9 @@ class TFLiteGPURunner {
  mediapipe::Status Build();
  mediapipe::Status Invoke();

  std::vector<BHWC> GetInputShapes() { return input_shapes_; }
  std::vector<BHWC> GetOutputShapes() { return output_shapes_; }

 private:
  mediapipe::Status InitializeOpenGL(
      std::unique_ptr<InferenceBuilder>* builder);
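Based only on the public methods visible in this header diff, a rough usage sketch; construction details and buffer binding are omitted, and the idea that Build() dispatches to the private InitializeOpenGL/InitializeOpenCL helpers is an inference, not something this diff confirms:

// Sketch only: `runner` is an already-constructed TFLiteGPURunner, and
// `model` / `op_resolver` come from the loaded .tflite flatbuffer.
MP_RETURN_IF_ERROR(runner->InitializeWithModel(model, op_resolver));
MP_RETURN_IF_ERROR(runner->Build());  // presumably selects the GL or CL path
const std::vector<BHWC> input_shapes = runner->GetInputShapes();
const std::vector<BHWC> output_shapes = runner->GetOutputShapes();
// ... bind GPU input/output buffers here ...
MP_RETURN_IF_ERROR(runner->Invoke());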
@@ -73,6 +80,10 @@ class TFLiteGPURunner {
  InferenceOptions options_;
  std::unique_ptr<gl::InferenceEnvironment> gl_environment_;

#ifdef __ANDROID__
  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#endif

  // graph_ is maintained temporarily and becomes invalid after runner_ is ready
  std::unique_ptr<GraphFloat32> graph_gl_;
  std::unique_ptr<GraphFloat32> graph_cl_;
third_party/com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff (new vendored file, 50 lines)
@@ -0,0 +1,50 @@
diff --git a/googletest/include/gtest/internal/gtest-internal.h b/googletest/include/gtest/internal/gtest-internal.h
index 7f1a5b00e..c36029ee1 100644
--- a/googletest/include/gtest/internal/gtest-internal.h
+++ b/googletest/include/gtest/internal/gtest-internal.h
@@ -94,6 +94,12 @@ namespace proto2 {
 class MessageLite;
 }
 
+namespace google {
+namespace protobuf {
+class MessageLite;
+}
+}
+
 namespace testing {
 
 // Forward declarations.
@@ -881,10 +887,15 @@ class GTEST_API_ Random {
 typename std::remove_const<typename std::remove_reference<T>::type>::type
 
 // IsAProtocolMessage<T>::value is a compile-time bool constant that's
-// true if and only if T is type proto2::MessageLite or a subclass of it.
+// true if and only if T is type proto2::MessageLite or
+// google::protobuf::MessageLite or a subclass of one of them.
 template <typename T>
 struct IsAProtocolMessage
-    : public std::is_convertible<const T*, const ::proto2::MessageLite*> {};
+    : public std::integral_constant<
+          bool,
+          std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
+          std::is_convertible<
+              const T*, const ::google::protobuf::MessageLite*>::value> {};
 
 // When the compiler sees expression IsContainerTest<C>(0), if C is an
 // STL-style container class, the first overload of IsContainerTest
diff --git a/googletest/test/gtest_unittest.cc b/googletest/test/gtest_unittest.cc
index 005a2d40d..631180e3d 100644
--- a/googletest/test/gtest_unittest.cc
+++ b/googletest/test/gtest_unittest.cc
@@ -7115,6 +7115,10 @@ TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAProtocolMessage) {
 EXPECT_TRUE(IsAProtocolMessage<::proto2::MessageLite>::value);
 }
 
+TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAnOpenSourceProtocolMessage) {
+  EXPECT_TRUE(IsAProtocolMessage<::google::protobuf::MessageLite>::value);
+}
+
 // Tests that IsAProtocolMessage<T>::value is false when T is neither
 // ::proto2::Message nor a sub-class of it.
 TEST(IsAProtocolMessageTest, ValueIsFalseWhenTypeIsNotAProtocolMessage) {