Project import generated by Copybara.

GitOrigin-RevId: e3a43e4e5e519cd14df7095749059e2613bdcf76
This commit is contained in:
MediaPipe Team 2020-07-08 17:34:05 -07:00 committed by jqtang
parent 67bd8a2bf0
commit e9fbe868e5
96 changed files with 2547 additions and 1176 deletions

2
BUILD
View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors.
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -101,7 +101,7 @@ run code search using
## Videos
* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
* [YouTube Channel](https://www.youtube.com/c/MediaPipe)
## Events
@ -123,7 +123,7 @@ run code search using
* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users
* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
community discussion around MediaPipe

View File

@ -37,10 +37,19 @@ http_archive(
)
# GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30.
http_archive(
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/master.zip"],
strip_prefix = "googletest-master",
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
patches = [
# fix for https://github.com/google/googletest/issues/2817
"@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
],
patch_args = [
"-p1",
],
strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
)
# Google Benchmark library.

74
build_ios_examples.sh Normal file
View File

@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
#
# Script to build all MediaPipe iOS example apps.
#
# To build all apps and store them in out_dir:
# $ ./build_ios_examples.sh -d out_dir
# Omitting -d and the associated directory saves all generated IPAs in the
# current directory.
# $ ./build_ios_examples.sh -d out_dir --nostrip
# Same as above except that the symnbols are not stripped.
set -e
out_dir="."
strip=true
app_dir="mediapipe/examples/ios"
bin_dir="bazel-bin"
declare -a default_bazel_flags=(build -c opt --config=ios_arm64)
while [[ -n $1 ]]; do
case $1 in
-d)
shift
out_dir=$1
;;
--nostrip)
strip=false
;;
*)
echo "Unsupported input argument $1."
exit 1
;;
esac
shift
done
echo "app_dir: $app_dir"
echo "out_dir: $out_dir"
echo "strip: $strip"
declare -a bazel_flags
apps="${app_dir}/*"
for app in ${apps}; do
if [[ -d "${app}" ]]; then
target_name=${app##*/}
target="${app}:${target_name}"
echo "=== Target: ${target}"
bazel_flags=("${default_bazel_flags[@]}")
bazel_flags+=(${target})
if [[ $strip == true ]]; then
bazel_flags+=(--linkopt=-s)
fi
bazel "${bazel_flags[@]}"
cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}"
fi
done

View File

@ -149,15 +149,15 @@ When possible, these calculators use platform-specific functionality to share da
The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU.
| ![How GPU calculators interact](../images/gpu_example_graph.png) |
| :--------------------------------------------------------------------------: |
| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. |
: The input stream is accessed by two calculators in parallel. :
: `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, :
: which is then sent through a grayscale converter and a canny filter (both :
: based on OpenCV and running on the CPU), whose output is then converted into :
: a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, :
: takes as input both the original `GpuBuffer` and the one coming out of the :
: edge detector, and overlays them using a shader. The output is then sent :
: back to the application using a callback calculator, and the application :
: renders the image to the screen using OpenGL.* :
![How GPU calculators interact](../images/gpu_example_graph.png)
Video frames from the camera are fed into the graph as `GpuBuffer` packets. The
input stream is accessed by two calculators in parallel.
`GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,
which is then sent through a grayscale converter and a canny filter (both based
on OpenCV and running on the CPU), whose output is then converted into a
`GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as
input both the original `GpuBuffer` and the one coming out of the edge detector,
and overlays them using a shader. The output is then sent back to the
application using a callback calculator, and the application renders the image
to the screen using OpenGL.

View File

@ -184,12 +184,8 @@ app:
### Prerequisite
1. Install [Xcode](https://developer.apple.com/xcode/) and the Command Line
Tools.
Follow Apple's instructions to obtain the required development certificates
and provisioning profiles for your iOS device. Install the Command Line
Tools by
1. Install [Xcode](https://developer.apple.com/xcode/), and additionally
install the Command Line Tools by:
```bash
xcode-select --install
@ -209,26 +205,31 @@ app:
pip3 install --user six
```
4. Clone the MediaPipe repository.
4. Follow
[Apple's instructions](https://developer.apple.com/support/certificates/) to
obtain the required development certificates and provisioning profiles for
your iOS device.
Tip: You can the following command to see the provisioning profiles you have
previously downloaded using Xcode: `open
~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
and download a profile on
[Apple's developer site](https://developer.apple.com/account/resources/).
5. Clone the MediaPipe repository.
```bash
git clone https://github.com/google/mediapipe.git
```
5. Symlink or copy your provisioning profile to
`mediapipe/mediapipe/provisioning_profile.mobileprovision`.
6. In the cloned MediaPipe repository, symlink or copy your provisioning profile
to `mediapipe/provisioning_profile.mobileprovision`, e.g.,
```bash
cd mediapipe
ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
```
Tip: You can use this command to see the provisioning profiles you have
previously downloaded using Xcode: `open
~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
and download a profile on
[Apple's developer site](https://developer.apple.com/account/resources/).
### Option 1: Build with Bazel in Command Line
1. Modify the `bundle_id` field of the app's `ios_application` build target to
@ -246,6 +247,10 @@ app:
You may see a permission request from `codesign` in order to sign the app.
Tip: You can run this
[script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
to build all MediaPipe iOS example apps.
3. In Xcode, open the `Devices and Simulators` window (command-shift-2).
4. Make sure your device is connected. You will see a list of installed apps.

View File

@ -44,6 +44,18 @@ apps, see these [instructions](./building_examples.md#ios).
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
to install Bazel 2.0 or higher.
For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to
be built from source.
```bash
# For Bazel 3.0.0
wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
unzip bazel-3.0.0-dist.zip
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
sudo cp output/bazel /usr/local/bin/
```
3. Install OpenCV and FFmpeg.
Option 1. Use package manager tool to install the pre-compiled OpenCV
@ -58,6 +70,14 @@ apps, see these [instructions](./building_examples.md#ios).
libopencv-imgproc-dev libopencv-video-dev
```
[`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia
Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be
modified.
```bash
sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
```
Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source
and modify MediaPipe's OpenCV config.
@ -493,14 +513,14 @@ cameras. Alternatively, you use a video file as input.
```bash
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
https://storage.googleapis.com/bazel/2.0.0/release/bazel-2.0.0-installer-linux-x86_64.sh && \
sudo mkdir -p /usr/local/bazel/2.0.0 && \
chmod 755 bazel-2.0.0-installer-linux-x86_64.sh && \
sudo ./bazel-2.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/2.0.0 && \
source /usr/local/bazel/2.0.0/lib/bazel/bin/bazel-complete.bash
https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \
sudo mkdir -p /usr/local/bazel/3.0.0 && \
chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \
sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \
source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel'
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel'
```
6. Checkout MediaPipe repository.

View File

@ -101,7 +101,7 @@ run code search using
## Videos
* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
* [YouTube Channel](https://www.youtube.com/c/MediaPipe)
## Events
@ -123,7 +123,7 @@ run code search using
* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users
* [Slack community](https://https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
community discussion around MediaPipe

View File

@ -1,6 +1,6 @@
---
layout: default
title: Hand
title: Hands
parent: Solutions
nav_order: 3
---
@ -219,9 +219,13 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web).
## Resources
* Google AI Blog: [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* TensorFlow Blog: [Face and hand tracking in the browser with MediaPipe and
TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
* Google AI Blog:
[On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* TensorFlow Blog:
[Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
* Paper:
[MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214)
([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk))
* Palm detection model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1)

View File

@ -188,5 +188,8 @@ to visualize its associated subgraphs, please see
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper:
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite)
* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite)

View File

@ -21,7 +21,16 @@ available on Linux, Android, or iOS.
## Enabling tracing and profiling
To enable tracing/profiling of a mediapipe graph, the `CalculatorGraphConfig` (in
To enable tracing and profiling of a mediapipe graph:
1. The profiling library must be linked to the framework.
2. Tracing and profiling must be enabled in the graph configuration.
The profiling library is linked to the framework by default. If needed,
the profiling library can be omitted from the framework using the bazel
command line option: `--define MEDIAPIPE_PROFILING=0`.
To enable tracing and profiling, the `CalculatorGraphConfig` (in
[calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
representing the graph must have a `profiler_config` message at its root. Here
is a simple setup that turns on a few extra options:

View File

@ -107,7 +107,7 @@ class BilateralFilterCalculator : public CalculatorBase {
GLuint program_ = 0;
GLuint vao_;
GLuint vbo_[2]; // vertex storage
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // !MEDIAPIPE_DISABLE_GPU
};
REGISTER_CALCULATOR(BilateralFilterCalculator);

View File

@ -386,45 +386,47 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const int input_width = input_mat.cols;
const int input_height = input_mat.rows;
if (!output_height_ || !output_width_) {
output_height_ = input_height;
output_width_ = input_width;
}
int output_width;
int output_height;
ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height);
cv::Mat scaled_mat;
int output_width = output_width_;
int output_height = output_height_;
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
int scale_flag =
input_mat.cols > output_width_ && input_mat.rows > output_height_
? cv::INTER_AREA
: cv::INTER_LINEAR;
cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_),
0, 0, scale_flag);
} else {
const float scale =
std::min(static_cast<float>(output_width_) / input_width,
static_cast<float>(output_height_) / input_height);
const int target_width = std::round(input_width * scale);
const int target_height = std::round(input_height * scale);
int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
cv::Mat intermediate_mat;
cv::resize(input_mat, intermediate_mat,
cv::Size(target_width, target_height), 0, 0, scale_flag);
const int top = (output_height_ - target_height) / 2;
const int bottom = output_height_ - target_height - top;
const int left = (output_width_ - target_width) / 2;
const int right = output_width_ - target_width - left;
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left, right,
options_.constant_padding() ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE);
} else {
cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height),
if (output_width_ > 0 && output_height_ > 0) {
cv::Mat scaled_mat;
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
int scale_flag =
input_mat.cols > output_width_ && input_mat.rows > output_height_
? cv::INTER_AREA
: cv::INTER_LINEAR;
cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_),
0, 0, scale_flag);
output_width = target_width;
output_height = target_height;
} else {
const float scale =
std::min(static_cast<float>(output_width_) / input_width,
static_cast<float>(output_height_) / input_height);
const int target_width = std::round(input_width * scale);
const int target_height = std::round(input_height * scale);
int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
cv::Mat intermediate_mat;
cv::resize(input_mat, intermediate_mat,
cv::Size(target_width, target_height), 0, 0, scale_flag);
const int top = (output_height_ - target_height) / 2;
const int bottom = output_height_ - target_height - top;
const int left = (output_width_ - target_width) / 2;
const int right = output_width_ - target_width - left;
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left,
right,
options_.constant_padding() ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE);
} else {
cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height),
0, 0, scale_flag);
output_width = target_width;
output_height = target_height;
}
}
input_mat = scaled_mat;
}
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
@ -437,10 +439,33 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
}
cv::Mat rotated_mat;
const int angle = RotationModeToDegrees(rotation_);
cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size());
cv::Size rotated_size(output_width, output_height);
if (input_mat.size() == rotated_size) {
const int angle = RotationModeToDegrees(rotation_);
cv::Point2f src_center(input_mat.cols / 2.0, input_mat.rows / 2.0);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
} else {
switch (rotation_) {
case mediapipe::RotationMode_Mode_UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0:
LOG(ERROR) << "Not rotating image.";
rotated_mat = input_mat;
break;
case mediapipe::RotationMode_Mode_ROTATION_90:
LOG(ERROR) << "Rotating image by 90 degrees ccw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
break;
case mediapipe::RotationMode_Mode_ROTATION_180:
LOG(ERROR) << "Rotating image by 180 degrees.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
break;
case mediapipe::RotationMode_Mode_ROTATION_270:
LOG(ERROR) << "Rotating image by 90 degrees cw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
break;
}
}
cv::Mat flipped_mat;
if (flip_horizontally_ || flip_vertically_) {
@ -498,7 +523,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
renderer = yuv_renderer_.get();
src1 = gpu_helper_.CreateSourceTexture(input, 0);
} else // NOLINT(readability/braces)
#endif // iOS
#endif // iOS
{
src1 = gpu_helper_.CreateSourceTexture(input);
#if defined(TEXTURE_EXTERNAL_OES)
@ -510,7 +535,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
}
renderer = ext_rgb_renderer_.get();
} else // NOLINT(readability/braces)
#endif // TEXTURE_EXTERNAL_OES
#endif // TEXTURE_EXTERNAL_OES
{
if (!rgb_renderer_) {
rgb_renderer_ = absl::make_unique<QuadRenderer>();

View File

@ -139,7 +139,6 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
static_cast<::mediapipe::StatusCode>(status.code()),
status.ToString());
}
auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session);

View File

@ -14,6 +14,7 @@
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
licenses(["notice"]) # Apache 2.0
@ -202,6 +203,13 @@ cc_library(
alwayslink = 1,
)
selects.config_setting_group(
name = "gpu_inference_disabled",
match_any = [
"//mediapipe/gpu:disable_gpu",
],
)
cc_library(
name = "tflite_inference_calculator",
srcs = ["tflite_inference_calculator.cc"],
@ -226,13 +234,14 @@ cc_library(
"@com_google_absl//absl/memory",
"//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util",
"//mediapipe/util/tflite:config",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/port:ret_check",
] + select({
"//mediapipe/gpu:disable_gpu": [],
] + selects.with_or({
":gpu_inference_disabled": [],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:MPPMetalUtil",
@ -285,6 +294,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/util/tflite:config",
":util",
":tflite_converter_calculator_cc_proto",
"//mediapipe/util:resource_util",
@ -295,23 +305,26 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
] + select({
"//mediapipe/gpu:disable_gpu": [],
] + selects.with_or({
":gpu_inference_disabled": [],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gl_calculator_helper",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
],
}) + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer",
],
}),
alwayslink = 1,
)
@ -348,8 +361,8 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework",
] + select({
"//mediapipe/gpu:disable_gpu": [],
] + selects.with_or({
":gpu_inference_disabled": [],
"//mediapipe:ios": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper",
@ -404,6 +417,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/util/tflite:config",
":util",
":tflite_tensors_to_detections_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
@ -415,8 +429,8 @@ cc_library(
"//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework",
] + select({
"//mediapipe/gpu:disable_gpu": [],
] + selects.with_or({
":gpu_inference_disabled": [],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer",
@ -492,6 +506,8 @@ cc_library(
alwayslink = 1,
)
# To run this with native GPU on Linux, use:
# bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local
cc_test(
name = "tflite_inference_calculator_test",
srcs = ["tflite_inference_calculator_test.cc"],

View File

@ -22,19 +22,23 @@
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#include "mediapipe/gpu/gl_calculator_helper.h"
#ifndef MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h"
#endif // MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS)
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
@ -43,13 +47,7 @@
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(MEDIAPIPE_IOS)
typedef id<MTLBuffer> GpuTensor;
#endif
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
@ -73,7 +71,7 @@ constexpr char kMatrixTag[] = "MATRIX";
namespace mediapipe {
namespace {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;
@ -83,13 +81,13 @@ struct GPUData {
GlShader shader;
GlProgram program;
};
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
struct GPUData {
int elements = 1;
GpuTensor buffer;
id<MTLComputePipelineState> pipeline_state;
};
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} // namespace
@ -157,13 +155,13 @@ class TfLiteConverterCalculator : public CalculatorBase {
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_out_;
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
bool initialized_ = false;
bool use_gpu_ = false;
@ -178,6 +176,18 @@ class TfLiteConverterCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(TfLiteConverterCalculator);
namespace {
template <class CC>
bool ShouldUseGpu(CC* cc) {
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
return cc->Inputs().HasTag(kGpuBufferTag) ||
cc->Outputs().HasTag(kTensorsGpuTag);
#else
return false;
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
} // namespace
::mediapipe::Status TfLiteConverterCalculator::GetContract(
CalculatorContract* cc) {
// Confirm only one of the input streams is present.
@ -189,37 +199,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
cc->Outputs().HasTag(kTensorsGpuTag));
bool use_gpu = false;
if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag(kMatrixTag)) {
cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
}
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
#ifndef MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kGpuBufferTag)) {
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kTensorsTag)) {
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
}
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (ShouldUseGpu(cc)) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
// Assign this calculator's default InputStreamHandler.
@ -233,14 +237,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag(kGpuBufferTag) ||
cc->Outputs().HasTag(kGpuBufferTag)) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
}
use_gpu_ = ShouldUseGpu(cc);
if (use_gpu_) {
// Cannot mix CPU/GPU streams.
@ -248,12 +245,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs().HasTag(kTensorsGpuTag));
// Cannot use quantization.
use_quantized_tensors_ = false;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else {
interpreter_ = absl::make_unique<tflite::Interpreter>();
interpreter_->AddTensors(1);
@ -282,12 +279,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
interpreter_.reset();
#if MEDIAPIPE_TFLITE_GL_INFERENCE
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif
#if defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_data_out_.reset();
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}
@ -318,8 +315,14 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
<< "Only 8-bit input images are supported for quantization.";
quant.type = kTfLiteAffineQuantization;
quant.params = nullptr;
// Optional: Set 'quant' quantization params here if needed.
auto quant_params = static_cast<TfLiteAffineQuantization*>(
malloc(sizeof(TfLiteAffineQuantization)));
quant_params->scale = TfLiteFloatArrayCreate(1);
quant_params->scale->data[0] = 1.0;
quant_params->zero_point = TfLiteIntArrayCreate(1);
quant_params->zero_point->data[0] = 0;
quant_params->quantized_dimension = 0;
quant.params = quant_params;
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
{channels_preserved}, quant);
} else {
@ -414,7 +417,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
// GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -451,7 +454,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs()
.Tag(kTensorsGpuTag)
.Add(output_tensors.release(), cc->InputTimestamp());
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
// GpuBuffer to id<MTLBuffer> conversion.
const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -490,13 +493,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
.Add(output_tensors.release(), cc->InputTimestamp());
#else
RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
// Get input image sizes.
const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -512,9 +515,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK_FAIL() << "Unsupported GPU input format.";
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
// Device memory.
@ -559,7 +562,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
RET_CHECK(include_alpha)
<< "iOS GPU inference currently accepts only RGBA input.";
@ -616,7 +619,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(gpu_data_out_->pipeline_state != nil)
<< "Couldn't create pipeline state "
<< [[error localizedDescription] UTF8String];
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}

View File

@ -22,6 +22,7 @@
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h"
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
#include "mediapipe/util/cpu_util.h"
@ -33,7 +34,7 @@
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
@ -42,9 +43,9 @@
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // !MEDIAPIPE_DISABLE_GL_COMPUTE
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS)
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
@ -56,7 +57,7 @@
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
#endif // iOS
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
#if !defined(MEDIAPIPE_EDGE_TPU)
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
@ -71,12 +72,6 @@ int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size;
}
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(MEDIAPIPE_IOS)
typedef id<MTLBuffer> GpuTensor;
#endif
// Round up n to next multiple of m.
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
@ -112,13 +107,13 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
// * Aux
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer;
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
namespace {
struct GPUData {
int elements = 1;
@ -126,7 +121,7 @@ struct GPUData {
::tflite::gpu::BHWC shape;
};
} // namespace
#endif
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
// Returns number of threads to configure XNNPACK delegate with.
// (Equal to user provided value if specified. Otherwise, it returns number of
@ -152,7 +147,7 @@ int GetXnnpackNumThreads(
// Creates an interpreter with given model and calls invoke().
// Optionally run inference on CPU/GPU.
//
// This calculator is designed to be used with the TfLiteConverterCalcualtor,
// This calculator is designed to be used with the TfLiteConverterCalculator,
// to get the appropriate inputs.
//
// When the input tensors are on CPU, gpu inference is optional and can be
@ -183,7 +178,6 @@ int GetXnnpackNumThreads(
// options: {
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
// model_path: "modelname.tflite"
// delegate { gpu {} }
// }
// }
// }
@ -192,11 +186,12 @@ int GetXnnpackNumThreads(
//
// node {
// calculator: "TfLiteInferenceCalculator"
// input_stream: "TENSORS:tensor_image"
// input_stream: "TENSORS_GPU:tensor_image"
// input_side_packet: "MODEL:model"
// output_stream: "TENSORS:tensors"
// output_stream: "TENSORS_GPU:tensors"
// options: {
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
// model_path: "modelname.tflite"
// delegate { gpu {} }
// }
// }
@ -228,24 +223,45 @@ class TfLiteInferenceCalculator : public CalculatorBase {
::mediapipe::Status LoadModel(CalculatorContext* cc);
::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
::mediapipe::Status LoadDelegate(CalculatorContext* cc);
::mediapipe::Status InitTFLiteGPURunner();
::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
::mediapipe::Status ProcessInputsCpu(
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
::mediapipe::Status ProcessOutputsCpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
::mediapipe::Status ProcessInputsGpu(
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
::mediapipe::Status ProcessOutputsGpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
::mediapipe::Status RunInContextIfNeeded(
std::function<::mediapipe::Status(void)> f) {
if (gpu_inference_) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
return gpu_helper_.RunInGlContext(std::move(f));
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
return f();
}
Packet model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_;
TfLiteDelegatePtr delegate_;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_;
std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MPPMetalHelper* gpu_helper_ = nullptr;
std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
id<MTLComputePipelineState> fp32_to_fp16_program_;
TFLBufferConvert* converter_from_BPHWC4_ = nil;
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_EDGE_TPU)
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ =
@ -263,6 +279,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Calculator Core Section
namespace {
template <class CC>
bool ShouldUseGpu(CC* cc) {
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
const auto& options =
cc->template Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
return options.use_gpu() ||
(options.has_delegate() && options.delegate().has_gpu()) ||
cc->Inputs().HasTag(kTensorsGpuTag) ||
cc->Outputs().HasTag(kTensorsGpuTag);
#else
return false;
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
} // namespace
::mediapipe::Status TfLiteInferenceCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^
@ -276,32 +308,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->InputSidePackets().HasTag("MODEL"))
<< "Either model as side packet or model path in options is required.";
bool use_gpu =
options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
if (cc->Inputs().HasTag(kTensorsTag))
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
<< "GPU input is compatible with GPU delegate only.";
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kTensorsTag))
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
<< "GPU output is compatible with GPU delegate only.";
if (cc->Inputs().HasTag(kTensorsGpuTag))
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
if (cc->Outputs().HasTag(kTensorsGpuTag))
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
cc->InputSidePackets()
@ -312,10 +327,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
}
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (ShouldUseGpu(cc)) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif
}
@ -331,123 +346,181 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
const auto& options =
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
gpu_inference_ = options.use_gpu();
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_input_ = true;
gpu_inference_ = true; // Inference must be on GPU also.
#else
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
<< "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
}
gpu_inference_ = ShouldUseGpu(cc);
gpu_input_ = cc->Inputs().HasTag(kTensorsGpuTag);
gpu_output_ = cc->Outputs().HasTag(kTensorsGpuTag);
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_output_ = true;
RET_CHECK(cc->Inputs().HasTag(kTensorsGpuTag))
<< "GPU output must also have GPU Input.";
#else
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
<< "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
}
use_advanced_gpu_api_ = false;
if (use_advanced_gpu_api_ && !(gpu_input_ && gpu_output_)) {
LOG(WARNING)
<< "Cannot use advanced GPU APIs, both inputs and outputs must "
"be GPU buffers. Falling back to the default TFLite API.";
use_advanced_gpu_api_ = MEDIAPIPE_TFLITE_GL_INFERENCE &&
options.has_delegate() &&
options.delegate().has_gpu() &&
options.delegate().gpu().use_advanced_gpu_api();
if (use_advanced_gpu_api_ && !gpu_input_) {
LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
"Falling back to the default TFLite API.";
use_advanced_gpu_api_ = false;
}
CHECK(!use_advanced_gpu_api_ || gpu_inference_);
MP_RETURN_IF_ERROR(LoadModel(cc));
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(MEDIAPIPE_IOS)
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner()
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: LoadDelegate(cc);
}));
if (use_advanced_gpu_api_) return ::mediapipe::OkStatus();
#else
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif
} else {
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID)
// TODO: why only on these platforms?
// It seems that the XNNPACK delegate fails to load on Linux.
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \
defined(MEDIAPIPE_IOS)
MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif // __EMSCRIPTEN__ || ANDROID
#endif // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
// 0. Declare outputs
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || defined(MEDIAPIPE_IOS)
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
#endif
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
// 0. Declare outputs
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
// 1. Receive pre-processed tensor inputs.
if (use_advanced_gpu_api_ && gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
// 1. Receive pre-processed tensor inputs.
if (gpu_input_) {
MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
} else {
MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
}
// 2. Run inference.
#if MEDIAPIPE_TFLITE_GL_INFERENCE
if (gpu_inference_ && use_advanced_gpu_api_) {
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
#else
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
// 3. Output processed tensors.
if (gpu_output_ || use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
std::move(output_tensors_gpu)));
} else {
MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
}
return ::mediapipe::OkStatus();
});
}
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
if (delegate_) {
interpreter_ = nullptr;
delegate_ = nullptr;
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
if (gpu_inference_) {
for (int i = 0; i < gpu_data_in_.size(); ++i) {
gpu_data_in_[i].reset();
}
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
}
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
#if defined(MEDIAPIPE_EDGE_TPU)
edgetpu_context_.reset();
#endif
return ::mediapipe::OkStatus();
});
}
// Calculator Auxiliary Section
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
// Read CPU input into tensors.
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteTensor* input_tensor = &input_tensors[i];
RET_CHECK(input_tensor->data.raw);
if (use_quantized_tensors_) {
const uint8* input_tensor_buffer = input_tensor->data.uint8;
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
} else {
const float* input_tensor_buffer = input_tensor->data.f;
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu(
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu) {
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
if (use_advanced_gpu_api_) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK(!input_tensors.empty());
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors_gpu]() -> ::mediapipe::Status {
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].id(), i));
}
// Allocate output tensor.
output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i);
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &tensor));
MP_RETURN_IF_ERROR(
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
}
return ::mediapipe::OkStatus();
}));
#endif
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(
tflite_gpu_runner_->BindSSBOToInputTensor(input_tensors[i].id(), i));
}
if (gpu_output_) {
// Allocate new output tensor.
output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i);
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &tensor));
MP_RETURN_IF_ERROR(
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
}
} else {
// Re-use internal output tensor.
for (int i = 0; i < gpu_data_out_.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
gpu_data_out_[i]->buffer.id(), i));
}
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else if (gpu_input_) {
// Read GPU input into SSBO.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
// Explicit copy input.
gpu_data_in_.resize(input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
RET_CHECK_CALL(
CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
}
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
// Explicit copy input.
gpu_data_in_.resize(input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
}
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
@ -470,79 +543,70 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
[compute_encoder endEncoding];
[command_buffer commit];
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
} else {
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
// Read CPU input into tensors.
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteTensor* input_tensor = &input_tensors[i];
RET_CHECK(input_tensor->data.raw);
if (use_quantized_tensors_) {
const uint8* input_tensor_buffer = input_tensor->data.uint8;
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
} else {
const float* input_tensor_buffer = input_tensor->data.f;
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
}
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
// 2. Run inference.
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
if (use_advanced_gpu_api_) {
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
#endif
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
return ::mediapipe::OkStatus();
}
// 3. Output processed tensors.
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
// Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
output_tensors_cpu->emplace_back(*tensor);
}
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
if (use_advanced_gpu_api_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
cc->Outputs()
.Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
#endif
#if MEDIAPIPE_TFLITE_GL_INFERENCE
if (gpu_output_) {
// Send out pre-allocated tensors.
cc->Outputs()
.Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
} else {
// Download to CPU for output.
const auto& tensor_indexes = interpreter_->inputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
std::vector<float> gpu_data(tensor->bytes / sizeof(float));
RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
absl::MakeSpan(tensor->data.f, tensor->bytes)));
output_tensors_cpu->emplace_back(*tensor);
}
// Output result tensors (CPU).
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else if (gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
// Output result tensors (GPU).
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &output_tensors_gpu]() -> ::mediapipe::Status {
output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i);
// Allocate output tensor.
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &tensor));
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
}
return ::mediapipe::OkStatus();
}));
output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i);
// Allocate output tensor.
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &tensor));
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
}
cc->Outputs()
.Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
// Output result tensors (GPU).
output_tensors_gpu->resize(gpu_data_out_.size());
id<MTLDevice> device = gpu_helper_.mtlDevice;
@ -566,68 +630,58 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->Outputs()
.Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp());
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
} else {
// Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
output_tensors_cpu->emplace_back(*tensor);
}
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
if (delegate_) {
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
interpreter_ = nullptr;
delegate_ = nullptr;
for (int i = 0; i < gpu_data_in_.size(); ++i) {
gpu_data_in_[i].reset();
}
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
interpreter_ = nullptr;
delegate_ = nullptr;
for (int i = 0; i < gpu_data_in_.size(); ++i) {
gpu_data_in_[i].reset();
}
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
#endif
} else {
interpreter_ = nullptr;
delegate_ = nullptr;
}
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
CalculatorContext* cc) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER")
.Get<tflite::ops::builtin::BuiltinOpResolver>();
}
#if defined(MEDIAPIPE_EDGE_TPU)
edgetpu_context_.reset();
#endif
return ::mediapipe::OkStatus();
}
// Calculator Auxiliary Section
// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
// Allocate interpreter memory for cpu output.
if (!gpu_output_) {
interpreter_ = absl::make_unique<tflite::Interpreter>();
const int num_outputs = tflite_gpu_runner_->GetOutputShapes().size();
interpreter_->AddTensors(num_outputs);
std::vector<int> indices(num_outputs);
for (int i = 0; i < num_outputs; ++i) indices[i] = i;
// There is no ResizeOutputTensor(), so we use 'inputs' space instead.
interpreter_->SetInputs(indices);
TfLiteQuantization quant;
quant.type = kTfLiteNoQuantization;
quant.params = nullptr;
for (int i = 0; i < num_outputs; ++i) {
auto shape = tflite_gpu_runner_->GetOutputShapes()[i];
const int tensor_idx = interpreter_->inputs()[i];
interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
{shape.c}, quant);
CHECK(interpreter_->ResizeInputTensor(
tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk);
}
CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
}
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner() {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
// Create and bind OpenGL buffers for outputs.
// These buffers are created onve and later their ids are jut passed to the
// calculator outputs.
// The buffers are created once and their ids are passed to calculator outputs
gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
gpu_data_out_[i] = absl::make_unique<GPUData>();
@ -638,15 +692,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
}
RET_CHECK_CALL(tflite_gpu_runner_->Build());
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
CalculatorContext* cc) {
if (use_advanced_gpu_api_) {
// Use InitTFLiteGPURunner for everything.
return ::mediapipe::OkStatus();
}
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets()
@ -654,19 +713,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
.Get<tflite::ops::builtin::BuiltinOpResolver>();
}
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (use_advanced_gpu_api_) {
tflite::gpu::InferenceOptions options;
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
tflite_gpu_runner_ =
std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
return tflite_gpu_runner_->InitializeWithModel(model, op_resolver);
}
#endif
#if defined(MEDIAPIPE_EDGE_TPU)
interpreter_ =
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
@ -771,7 +817,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
return ::mediapipe::OkStatus();
}
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1;
@ -832,9 +878,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
#endif // OpenGL
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS)
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
const int kHalfSize = 2; // sizeof(half)
// Configure and create the delegate.
TFLGpuDelegateOptions options;
@ -958,7 +1004,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
"Error initializating output buffer converter");
}
}
#endif // iOS
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
return ::mediapipe::OkStatus();
}

View File

@ -45,6 +45,8 @@ message TfLiteInferenceCalculatorOptions {
message Gpu {
// Experimental, Android/Linux only. Use TFLite GPU delegate API2 for
// the NN inference.
// example:
// delegate: { gpu { use_advanced_gpu_api: true } }
optional bool use_advanced_gpu_api = 1 [default = false];
}
// Android only.

View File

@ -25,17 +25,18 @@
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS)
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
@ -44,7 +45,7 @@
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace {
constexpr int kNumInputTensorsWithAnchors = 3;
@ -56,22 +57,17 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlShader;
#endif
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
#elif defined(MEDIAPIPE_IOS)
typedef id<MTLBuffer> GpuTensor;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
typedef id<MTLComputePipelineState> GpuProgram;
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
struct GPUData {
GpuProgram decode_program;
GpuProgram score_program;
@ -81,7 +77,7 @@ struct GPUData {
GpuTensor scored_boxes_buffer;
GpuTensor raw_scores_buffer;
};
#endif
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
std::vector<Anchor>* anchors) {
@ -181,13 +177,13 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
std::vector<Anchor> anchors_;
bool side_packet_anchors_{};
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_;
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_;
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
bool gpu_input_ = false;
bool anchors_init_ = false;
@ -205,12 +201,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
}
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("DETECTIONS")) {
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
@ -223,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
return ::mediapipe::OkStatus();
@ -239,12 +233,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
gpu_input_ = true;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
MP_RETURN_IF_ERROR(LoadOptions(cc));
@ -401,7 +395,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2);
@ -464,7 +458,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
@ -546,17 +540,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
#else
LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_data_.reset();
#endif
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}
@ -705,7 +699,7 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status {
gpu_data_ = absl::make_unique<GPUData>();
@ -918,7 +912,7 @@ void main() {
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_data_ = absl::make_unique<GPUData>();
id<MTLDevice> device = gpu_helper_.mtlDevice;
@ -1148,7 +1142,7 @@ kernel void scoreKernel(
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
}
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus();
}

View File

@ -217,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
const Landmark& landmark = output_landmarks.landmark(i);
NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark();
norm_landmark->set_x(static_cast<float>(landmark.x()) /
options_.input_image_width());
norm_landmark->set_y(static_cast<float>(landmark.y()) /
options_.input_image_height());
norm_landmark->set_z(landmark.z() / options_.normalize_z());
norm_landmark->set_x(landmark.x() / options_.input_image_width());
norm_landmark->set_y(landmark.y() / options_.input_image_height());
// Scale Z coordinate as X + allow additional uniform normalization.
norm_landmark->set_z(landmark.z() / options_.input_image_width() /
options_.normalize_z());
norm_landmark->set_visibility(landmark.visibility());
}
cc->Outputs()

View File

@ -29,7 +29,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
required int32 num_landmarks = 1;
// Size of the input image for the model. These options are used only when
// normalized landmarks is needed.
// normalized landmarks are needed. Z coordinate is scaled as X assuming
// a weak perspective projection camera model.
optional int32 input_image_width = 2;
optional int32 input_image_height = 3;
@ -46,6 +47,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
// beforehand.
optional bool flip_horizontally = 6 [default = false];
// A value that z values should be divided by.
// A value that Z coordinates should be divided by. This option is used only
// when normalized landmarks are needed. It is applied in addition to Z
// coordinate being re-scaled as X.
optional float normalize_z = 5 [default = 1.0];
}

View File

@ -376,6 +376,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":timed_box_list_id_to_label_calculator_cc_proto",
"@com_google_absl//absl/container:node_hash_map",
"//mediapipe/framework/port:status",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",

View File

@ -122,11 +122,13 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase {
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
const float new_x = (landmark.x() - left) / (1.0f - left_and_right);
const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom);
const float new_z =
landmark.z() / (1.0f - left_and_right); // Scale Z coordinate as X.
new_landmark->set_x(new_x);
new_landmark->set_y(new_y);
// Keep z-coord as is.
new_landmark->set_z(landmark.z());
new_landmark->set_z(new_z);
// Keep visibility as is.
new_landmark->set_visibility(landmark.visibility());
}

View File

@ -123,11 +123,12 @@ class LandmarkProjectionCalculator : public CalculatorBase {
new_x = new_x * input_rect.width() + input_rect.x_center();
new_y = new_y * input_rect.height() + input_rect.y_center();
const float new_z =
landmark.z() * input_rect.width(); // Scale Z coordinate as X.
new_landmark->set_x(new_x);
new_landmark->set_y(new_y);
// Keep z-coord as is.
new_landmark->set_z(landmark.z());
new_landmark->set_z(new_z);
// Keep visibility as is.
new_landmark->set_visibility(landmark.visibility());
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
@ -53,7 +54,7 @@ class TimedBoxListIdToLabelCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
std::unordered_map<int, std::string> label_map_;
absl::node_hash_map<int, std::string> label_map_;
};
REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator);

View File

@ -1,4 +1 @@
MediaPipe Examples
==================
This directory contains MediaPipe Android example applications. Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details.
This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev)for details.

View File

@ -1,7 +0,0 @@
tricorder: {
options: {
builder: {
config: "android_arm64"
}
}
}

View File

@ -83,7 +83,7 @@ android_binary(
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.apps.objectdetection3d",
"appName": "Object Detection 3D",
"appName": "Objectron",
"mainActivity": ".MainActivity",
"cameraFacingFront": "False",
"binaryGraphName": "object_detection_3d.binarypb",

View File

@ -1,113 +1 @@
**Hello World**
To build the "Hello World" example, use:
```
bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world
```
and then run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hello_world/hello_world
```
**TFlite Object Detection**
To build the object detection demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TensorFlow Object Detection**
To build the object detection demo using a TensorFlow model on desktop, use:
```
export GLOG_logtostderr=1
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \
--define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Hand Detection**
To build the hand detection demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Hand Tracking**
To build the hand tracking demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Multi-Hand Tracking**
To build the multi-hand tracking demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
To change the number of hands to `x` in this application, change:
1. `min_size:x` in `CollectionHasMinSizeCalculatorOptions` in `mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt`.
2. `max_vec_size:x` in `ClipVectorSizeCalculatorOptions` in `mediapipe/examples/dekstop/hand_tracking/subgraphs/multi_hand_detection_cpu.pbtxt`.
This directory contains MediaPipe example applications for desktop. Please see [Solutions](https://solutions.mediapipe.dev)for details.

View File

@ -62,8 +62,10 @@ cc_library(
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
@ -126,17 +128,20 @@ cc_test(
":content_zooming_calculator",
":content_zooming_calculator_cc_proto",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:benchmark",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)

View File

@ -19,16 +19,20 @@
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
constexpr char kVideoFrame[] = "VIDEO";
constexpr char kVideoSize[] = "VIDEO_SIZE";
constexpr char kDetectionSet[] = "DETECTIONS";
constexpr char kSalientRegions[] = "SALIENT_REGIONS";
constexpr char kDetections[] = "DETECTIONS";
constexpr char kDetectedBorders[] = "BORDERS";
constexpr char kCropRect[] = "CROP_RECT";
// Field-of-view (degrees) of the camera's x-axis (width).
// TODO: Parameterize FOV based on camera specs.
constexpr float kWidthFieldOfView = 60;
@ -37,12 +41,12 @@ namespace mediapipe {
namespace autoflip {
// Content zooming calculator zooms in on content when a detection has
// "only_required" set true. It does this by computing the value of top/bottom
// borders to remove from the output and sends these to the
// SceneCroppingCalculator. When more than one detections are received the zoom
// box is calculated as the union of the detections. Typical applications
// include mobile makeover and autofliplive face reframing. Currently only
// supports y-dimension zooming.
// "only_required" set true or any raw detection input. It does this by
// computing the value of top/bottom borders to remove from the output and sends
// these to the SceneCroppingCalculator using BORDERS output or a full rect crop
// using CROP_RECT output. When more than one detections are received the
// zoom box is calculated as the union of the detections. Typical applications
// include mobile makeover and autofliplive face reframing.
class ContentZoomingCalculator : public CalculatorBase {
public:
ContentZoomingCalculator()
@ -56,26 +60,32 @@ class ContentZoomingCalculator : public CalculatorBase {
::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;
private:
// Converts bounds to tilt offset and height.
::mediapipe::Status ConvertToTiltZoom(float xmin, float xmax, float ymin,
float ymax, int* tilt_offset,
int* height);
// Converts bounds to tilt offset, pan offset and height.
::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
float ymax, int* tilt_offset,
int* pan_offset, int* height);
ContentZoomingCalculatorOptions options_;
// Detection frame width/height.
int frame_height_;
int frame_width_;
// Path solver used to smooth top/bottom border crop values.
std::unique_ptr<KinematicPathSolver> path_solver_height_;
std::unique_ptr<KinematicPathSolver> path_solver_width_;
std::unique_ptr<KinematicPathSolver> path_solver_offset_;
// Are parameters initialized.
bool initialized_;
// Stores the time of the last "only_required" input.
int64 last_only_required_detection_;
// Border values of last message with detection.
// Rect values of last message with detection(s).
int last_measured_height_;
int last_measured_x_offset_;
int last_measured_y_offset_;
// Min border values.
float min_height_value_;
// Target aspect ratio.
float target_aspect_;
// Max size of bounding box. If input/output aspect ratios are the same,
// will be 1.0. Else, will be less than 1.0 to prevent exceeding the size of
// the image in either dimension.
float max_frame_value_;
};
REGISTER_CALCULATOR(ContentZoomingCalculator);
@ -92,8 +102,18 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Input VIDEO or VIDEO_SIZE must be provided.";
}
cc->Inputs().Tag(kDetectionSet).Set<DetectionSet>();
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
if (cc->Inputs().HasTag(kSalientRegions)) {
cc->Inputs().Tag(kSalientRegions).Set<DetectionSet>();
}
if (cc->Inputs().HasTag(kDetections)) {
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
}
if (cc->Outputs().HasTag(kDetectedBorders)) {
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
}
if (cc->Outputs().HasTag(kCropRect)) {
cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
}
return ::mediapipe::OkStatus();
}
@ -108,29 +128,38 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
if (options_.has_min_motion_to_reframe()) {
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Deprecated min_motion_to_reframe was set, please set "
"in kinematic_options_zoom and kinematic_options_tilt directly.";
"in kinematic_options_zoom and kinematic_options_tilt "
"directly.";
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ContentZoomingCalculator::ConvertToTiltZoom(
::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
int* height) {
int* pan_offset, int* height) {
// Find center of the y-axis offset (for tilt control).
float y_center = ymin + (ymax - ymin) / 2;
// Find center of the x-axis offset (for pan control).
float x_center = xmin + (xmax - xmin) / 2;
// Find size and apply scale factor to y-axis.
float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
// Apply min zoom for cases where the target size is wider than input frame
// size.
fit_size = fmin(min_height_value_, fit_size);
// Apply max frame for cases where the target size is different than input
// frame size.
fit_size = fmin(max_frame_value_, fit_size);
// Prevent box from extending beyond the image.
if (y_center - fit_size / 2 < 0) {
y_center = fit_size / 2;
} else if (y_center + fit_size / 2 > 1) {
y_center = 1 - fit_size / 2;
}
if (x_center - fit_size / 2 < 0) {
x_center = fit_size / 2;
} else if (x_center + fit_size / 2 > 1) {
x_center = 1 - fit_size / 2;
}
// Scale to pixel coordinates.
*tilt_offset = frame_height_ * y_center;
*pan_offset = frame_width_ * x_center;
*height = frame_height_ * fit_size;
return ::mediapipe::OkStatus();
}
@ -151,6 +180,20 @@ namespace {
return ::mediapipe::OkStatus();
}
::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection,
float* xmin, float* xmax, float* ymin,
float* ymax) {
RET_CHECK(detection.location_data().format() ==
mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
<< "Face detection input is lacking required relative_bounding_box()";
const auto& location = detection.location_data().relative_bounding_box();
*xmin = fmin(*xmin, location.xmin());
*xmax = fmax(*xmax, location.xmin() + location.width());
*ymin = fmin(*ymin, location.ymin());
*ymax = fmax(*ymax, location.ymin() + location.height());
return ::mediapipe::OkStatus();
}
void MakeStaticFeatures(const int top_border, const int bottom_border,
const int frame_width, const int frame_height,
StaticFeatures* static_feature) {
@ -173,10 +216,8 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
::mediapipe::Status ContentZoomingCalculator::Process(
mediapipe::CalculatorContext* cc) {
if (cc->Inputs().HasTag(kVideoFrame)) {
cv::Mat frame = mediapipe::formats::MatView(
&cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>());
frame_width_ = frame.cols;
frame_height_ = frame.rows;
frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
} else if (cc->Inputs().HasTag(kVideoSize)) {
frame_width_ =
cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
@ -191,10 +232,14 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
path_solver_height_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_zoom(), 0, frame_height_,
static_cast<float>(frame_width_) / kWidthFieldOfView);
path_solver_width_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_pan(), 0, frame_width_,
static_cast<float>(frame_width_) / kWidthFieldOfView);
path_solver_offset_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_tilt(), 0, frame_height_,
static_cast<float>(frame_width_) / kWidthFieldOfView);
min_height_value_ = 1.0;
max_frame_value_ = 1.0;
target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
// If target size is set and wider than input aspect, make sure to always
// crop the min required amount.
if (options_.has_target_size()) {
@ -203,75 +248,107 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
RET_CHECK_GT(options_.target_size().height(), 0)
<< "Provided target height not valid.";
float input_aspect = frame_width_ / static_cast<float>(frame_height_);
float target_aspect = options_.target_size().width() /
static_cast<float>(options_.target_size().height());
min_height_value_ =
(input_aspect < target_aspect) ? input_aspect / target_aspect : 1.0;
target_aspect_ = options_.target_size().width() /
static_cast<float>(options_.target_size().height());
max_frame_value_ = std::min(input_aspect / target_aspect_,
target_aspect_ / input_aspect);
}
last_measured_height_ = min_height_value_ * frame_height_;
last_measured_height_ = max_frame_value_ * frame_height_;
last_measured_x_offset_ = target_aspect_ * frame_width_;
last_measured_y_offset_ = frame_width_ / 2;
initialized_ = true;
}
auto detection_set = cc->Inputs().Tag(kDetectionSet).Get<DetectionSet>();
bool only_required_found = false;
// Compute the box that contains all "is_required" detections.
float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
for (const auto& region : detection_set.detections()) {
if (!region.only_required()) {
continue;
if (cc->Inputs().HasTag(kSalientRegions)) {
auto detection_set = cc->Inputs().Tag(kSalientRegions).Get<DetectionSet>();
for (const auto& region : detection_set.detections()) {
if (!region.only_required()) {
continue;
}
only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
}
}
if (cc->Inputs().HasTag(kDetections)) {
auto raw_detections =
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
for (const auto& detection : raw_detections) {
only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges(detection, &xmin, &xmax, &ymin, &ymax));
}
only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
}
// Convert bounds to tilt/zoom and in pixel coordinates.
int offset, height;
MP_RETURN_IF_ERROR(
ConvertToTiltZoom(xmin, xmax, ymin, ymax, &offset, &height));
int offset_y, height, offset_x;
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
&offset_x, &height));
if (only_required_found) {
// A only required detection was found.
last_only_required_detection_ = cc->InputTimestamp().Microseconds();
last_measured_height_ = height;
last_measured_y_offset_ = offset;
last_measured_x_offset_ = offset_x;
last_measured_y_offset_ = offset_y;
} else if (cc->InputTimestamp().Microseconds() -
last_only_required_detection_ >=
options_.us_before_zoomout()) {
// No only_require detections found within salient regions packets arriving
// since us_before_zoomout duration.
height = min_height_value_ * frame_height_;
offset = frame_height_ / 2;
// No only_require detections found within salient regions packets
// arriving since us_before_zoomout duration.
height = max_frame_value_ * frame_height_;
offset_x = (target_aspect_ * height) / 2;
offset_y = frame_height_ / 2;
} else {
// No only detection found but using last detection due to
// duration_before_zoomout_us setting.
height = last_measured_height_;
offset = last_measured_y_offset_;
offset_x = last_measured_x_offset_;
offset_y = last_measured_y_offset_;
}
// Compute smoothed camera paths.
MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
height, cc->InputTimestamp().Microseconds()));
MP_RETURN_IF_ERROR(path_solver_width_->AddObservation(
offset_x, cc->InputTimestamp().Microseconds()));
MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
offset, cc->InputTimestamp().Microseconds()));
int path_size;
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_size));
int path_offset;
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset));
offset_y, cc->InputTimestamp().Microseconds()));
int path_height;
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height));
int path_offset_x;
MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x));
int path_offset_y;
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));
// Convert to top/bottom borders to remove.
int path_top = path_offset - path_size / 2;
int path_bottom = frame_height_ - (path_offset + path_size / 2);
int path_top = path_offset_y - path_height / 2;
int path_bottom = frame_height_ - (path_offset_y + path_height / 2);
// Transmit result downstream.
std::unique_ptr<StaticFeatures> features =
absl::make_unique<StaticFeatures>();
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
features.get());
cc->Outputs()
.Tag(kDetectedBorders)
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
// Transmit result downstream to scenecroppingcalculator.
if (cc->Outputs().HasTag(kDetectedBorders)) {
std::unique_ptr<StaticFeatures> features =
absl::make_unique<StaticFeatures>();
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
features.get());
cc->Outputs()
.Tag(kDetectedBorders)
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
}
// Transmit downstream to glcroppingcalculator.
if (cc->Outputs().HasTag(kCropRect)) {
auto gpu_rect = absl::make_unique<mediapipe::Rect>();
gpu_rect->set_x_center(path_offset_x);
gpu_rect->set_width(path_height * target_aspect_);
gpu_rect->set_y_center(path_offset_y);
gpu_rect->set_height(path_height);
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
Timestamp(cc->InputTimestamp()));
}
return ::mediapipe::OkStatus();
}

View File

@ -32,6 +32,8 @@ message ContentZoomingCalculatorOptions {
optional KinematicOptions kinematic_options_zoom = 6;
// Kinematic options for tilt (y-axis reframing.)
optional KinematicOptions kinematic_options_tilt = 7;
// Kinematic options for pan (x-axis reframing.)
optional KinematicOptions kinematic_options_pan = 10;
// Duration (in MicroSeconds) before returning to fully zoomed out position
// when no "only_required" frames are received.
optional int64 us_before_zoomout = 9 [default = 1000000];

View File

@ -16,10 +16,14 @@
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/benchmark.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -36,14 +40,14 @@ namespace {
const char kConfigA[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO:camera_frames"
input_stream: "DETECTIONS:detection_set"
input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders"
)";
const char kConfigB[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO:camera_frames"
input_stream: "DETECTIONS:detection_set"
input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders"
options: {
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
@ -58,10 +62,17 @@ const char kConfigB[] = R"(
const char kConfigC[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detection_set"
input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders"
)";
const char kConfigD[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detections"
output_stream: "CROP_RECT:rect"
)";
void CheckBorder(const StaticFeatures& static_features, int width, int height,
int top_border, int bottom_border) {
ASSERT_EQ(2, static_features.border().size());
@ -80,6 +91,43 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
EXPECT_EQ(Border::BOTTOM, part.relative_position());
}
void AddDetection(const cv::Rect_<float>& position, const int64 time,
CalculatorRunner* runner) {
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
mediapipe::Detection detection;
detection.mutable_location_data()->set_format(
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
detection.mutable_location_data()
->mutable_relative_bounding_box()
->set_height(position.height);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
position.width);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
position.x);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
position.y);
detections->push_back(detection);
runner->MutableInputs()
->Tag("DETECTIONS")
.packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
runner->MutableInputs()
->Tag("VIDEO_SIZE")
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
}
void CheckCropRect(const int x_center, const int y_center, const int width,
const int height, const int frame_number,
const std::vector<Packet>& output_packets) {
ASSERT_GT(output_packets.size(), frame_number);
const auto& rect = output_packets[frame_number].Get<mediapipe::Rect>();
EXPECT_EQ(rect.x_center(), x_center);
EXPECT_EQ(rect.y_center(), y_center);
EXPECT_EQ(rect.width(), width);
EXPECT_EQ(rect.height(), height);
}
TEST(ContentZoomingCalculatorTest, ZoomTest) {
auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
@ -98,7 +146,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs()
->Tag("DETECTIONS")
->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator.
@ -111,6 +159,66 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
CheckBorder(static_features, 1000, 1000, 495, 395);
}
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, PanConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(488, 550, 111, 111, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, TiltConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(450, 588, 111, 111, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, ZoomConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(450, 550, 139, 139, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
@ -129,7 +237,7 @@ TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs()
->Tag("DETECTIONS")
->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator.
@ -166,7 +274,7 @@ TEST(ContentZoomingCalculatorTest, TwoFacesWide) {
Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs()
->Tag("DETECTIONS")
->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator.
@ -191,7 +299,7 @@ TEST(ContentZoomingCalculatorTest, NoDetectionOnInit) {
Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs()
->Tag("DETECTIONS")
->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator.
@ -223,7 +331,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
runner->MutableInputs()
->Tag("DETECTIONS")
->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator.

View File

@ -37,7 +37,7 @@ node {
output_stream: "TENSORS:detection_tensors"
options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "face_detection_front.tflite"
model_path: "mediapipe/models/face_detection_front.tflite"
}
}
}
@ -118,7 +118,7 @@ node {
output_stream: "labeled_detections"
options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
label_map_path: "face_detection_front_labelmap.txt"
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
}
}
}

View File

@ -1,18 +1 @@
This directory contains example MediaPipe applications on iOS.
| Use Case | Directory |
|---------------------------------------|:-----------------------------------:|
| Edge Detection on GPU | edgedetection |
| Face Detection on CPU | facedetectioncpu |
| Face Detection on GPU | facedetectiongpu |
| Object Detection on CPU | objectdetectioncpu |
| Object Detection on GPU | objectdetectiongpu |
| Hand Detection on GPU | handdetectiongpu |
| Hand Tracking on GPU | handtrackinggpu |
For instance, to build an example app for face detection on CPU, run:
```bash
bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
```
(Note: with your own $XCODE_VERSION)
This directory contains MediaPipe example applications for iOS. Please see [Solutions](https://solutions.mediapipe.dev)for details.

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "edgedetectiongpu",
actual = "EdgeDetectionGpuApp",
)
ios_application(
name = "EdgeDetectionGpuApp",
bundle_id = "com.google.mediapipe.EdgeDetectionGpu",

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "facedetectioncpu",
actual = "FaceDetectionCpuApp",
)
ios_application(
name = "FaceDetectionCpuApp",
bundle_id = "com.google.mediapipe.FaceDetectionCpu",

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "facedetectiongpu",
actual = "FaceDetectionGpuApp",
)
ios_application(
name = "FaceDetectionGpuApp",
bundle_id = "com.google.mediapipe.FaceDetectionGpu",

View File

@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0"
alias(
name = "facemeshgpu",
actual = "FaceMeshGpuApp",
)
ios_application(
name = "FaceMeshGpuApp",
bundle_id = "com.google.mediapipe.FaceMeshGpu",

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "handdetectiongpu",
actual = "HandDetectionGpuApp",
)
ios_application(
name = "HandDetectionGpuApp",
bundle_id = "com.google.mediapipe.HandDetectionGpu",

View File

@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0"
alias(
name = "handtrackinggpu",
actual = "HandTrackingGpuApp",
)
ios_application(
name = "HandTrackingGpuApp",
bundle_id = "com.google.mediapipe.HandTrackingGpu",

View File

@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0"
alias(
name = "multihandtrackinggpu",
actual = "MultiHandTrackingGpuApp",
)
ios_application(
name = "MultiHandTrackingGpuApp",
bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "objectdetectioncpu",
actual = "ObjectDetectionCpuApp",
)
ios_application(
name = "ObjectDetectionCpuApp",
bundle_id = "com.google.mediapipe.ObjectDetectionCpu",

View File

@ -13,6 +13,7 @@
// limitations under the License.
#import "AppDelegate.h"
#import "ViewController.h"
@interface AppDelegate ()
@ -22,7 +23,14 @@
- (BOOL)application:(UIApplication *)application
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
// Override point for customization after application launch.
ViewController *viewController = (ViewController *)self.window.rootViewController;
NSURL *url = [launchOptions objectForKey:UIApplicationLaunchOptionsURLKey];
// Unattended testing on Firebase is enabled by custom URL schema.
if ([url.scheme isEqualToString:@"firebase-game-loop"]) {
[viewController setSourceMode:MediaPipeDemoSourceVideo];
} else {
[viewController setSourceMode:MediaPipeDemoSourceBackCamera];
}
return YES;
}

View File

@ -21,6 +21,11 @@ load(
"ios_application",
)
alias(
name = "objectdetectiongpu",
actual = "ObjectDetectionGpuApp",
)
ios_application(
name = "ObjectDetectionGpuApp",
bundle_id = "com.google.mediapipe.ObjectDetectionGpu",

View File

@ -38,5 +38,18 @@
<array>
<string>UIInterfaceOrientationPortrait</string>
</array>
<key>CFBundleURLTypes</key>
<array>
<dict>
<key>CFBundleURLName</key>
<string>com.google.firebase</string>
<key>CFBundleTypeRole</key>
<string>Editor</string>
<key>CFBundleURLSchemes</key>
<array>
<string>firebase-game-loop</string>
</array>
</dict>
</array>
</dict>
</plist>

View File

@ -14,6 +14,11 @@
#import <UIKit/UIKit.h>
@interface ViewController : UIViewController
typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) {
MediaPipeDemoSourceBackCamera,
MediaPipeDemoSourceVideo
};
@interface ViewController : UIViewController
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode;
@end

View File

@ -17,6 +17,7 @@
#import "mediapipe/objc/MPPGraph.h"
#import "mediapipe/objc/MPPCameraInputSource.h"
#import "mediapipe/objc/MPPLayerRenderer.h"
#import "mediapipe/objc/MPPPlayerInputSource.h"
static NSString* const kGraphName = @"mobile_gpu";
@ -35,6 +36,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
@implementation ViewController {
/// Handles camera access via AVCaptureSession library.
MPPCameraInputSource* _cameraSource;
MPPPlayerInputSource* _videoSource;
MediaPipeDemoSourceMode _sourceMode;
/// Inform the user when camera is unavailable.
IBOutlet UILabel* _noCameraLabel;
@ -47,6 +50,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
dispatch_queue_t _videoQueue;
}
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode {
_sourceMode = mode;
}
#pragma mark - Cleanup methods
- (void)dealloc {
@ -97,13 +104,6 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0);
_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute);
_cameraSource = [[MPPCameraInputSource alloc] init];
[_cameraSource setDelegate:self queue:_videoQueue];
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
// The frame's native format is rotated with respect to the portrait orientation.
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
self.mediapipeGraph.delegate = self;
// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
@ -119,27 +119,43 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
- (void)viewWillAppear:(BOOL)animated {
[super viewWillAppear:animated];
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
if (granted) {
[self startGraphAndCamera];
dispatch_async(dispatch_get_main_queue(), ^{
_noCameraLabel.hidden = YES;
});
}
}];
}
- (void)startGraphAndCamera {
// Start running self.mediapipeGraph.
NSError* error;
if (![self.mediapipeGraph startWithError:&error]) {
NSLog(@"Failed to start graph: %@", error);
}
// Start fetching frames from the camera.
dispatch_async(_videoQueue, ^{
[_cameraSource start];
});
switch (_sourceMode) {
case MediaPipeDemoSourceVideo: {
AVAsset* video =
[AVAsset assetWithURL:[[NSBundle mainBundle] URLForResource:@"object_detection"
withExtension:@"mov"]];
_videoSource = [[MPPPlayerInputSource alloc] initWithAVAsset:video];
[_videoSource setDelegate:self queue:_videoQueue];
dispatch_async(_videoQueue, ^{
[_videoSource start];
});
break;
}
case MediaPipeDemoSourceBackCamera:
_cameraSource = [[MPPCameraInputSource alloc] init];
[_cameraSource setDelegate:self queue:_videoQueue];
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
// The frame's native format is rotated with respect to the portrait orientation.
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
if (granted) {
dispatch_async(_videoQueue, ^{
[_cameraSource start];
});
dispatch_async(dispatch_get_main_queue(), ^{
_noCameraLabel.hidden = YES;
});
}
}];
break;
}
}
#pragma mark - MPPGraphDelegate methods
@ -164,7 +180,7 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
timestamp:(CMTime)timestamp
fromSource:(MPPInputSource*)source {
if (source != _cameraSource) {
if (source != _cameraSource && source != _videoSource) {
NSLog(@"Unknown source: %@", source);
return;
}

View File

@ -36,7 +36,7 @@ exports_files([
mediapipe_proto_library(
name = "calculator_proto",
srcs = ["calculator.proto"],
visibility = [":mediapipe_internal"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:mediapipe_options_proto",
@ -68,7 +68,7 @@ mediapipe_proto_library(
mediapipe_proto_library(
name = "calculator_profile_proto",
srcs = ["calculator_profile.proto"],
visibility = [":mediapipe_internal"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -830,6 +830,8 @@ cc_library(
":port",
":timestamp",
":type_map",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -1524,6 +1526,21 @@ cc_test(
],
)
cc_test(
name = "packet_registration_test",
size = "small",
srcs = ["packet_registration_test.cc"],
deps = [
":calculator_framework",
":packet",
":packet_test_cc_proto",
":type_map",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "packet_generator_test",
size = "small",

View File

@ -113,8 +113,11 @@ class CalculatorContract {
// calculations should use SetProcessTimestampBounds.
// When true, Process is called for every new timestamp bound, with or without
// new packets. A call to Process with only an input timestamp bound is
// new packets. A call to Process with only an input timestamp bound is
// normally used to compute a new output timestamp bound.
// NOTE: Also, when true, Process is called when input streams become done,
// which means, Process needs to handle input streams in "done" state.
// (Usually, by closing calculators' outputs where and when appropriate.)
void SetProcessTimestampBounds(bool process_timestamps) {
process_timestamps_ = process_timestamps;
}

View File

@ -91,6 +91,9 @@ typedef ::mediapipe::StatusOr<OutputStreamPoller> StatusOrPoller;
// {{"video_id", mediapipe::MakePacket<std::string>("Ex-uGhDzue4")}}));
// // See mediapipe/framework/graph_runner.h for an interface
// // to insert and extract packets from a graph as it runs.
// // Once it is done using the graph, close its streams and wait till done.
// MP_RETURN_IF_ERROR(graph->CloseAllInputStreams());
// MP_RETURN_IF_ERROR(graph->WaitUntilDone());
class CalculatorGraph {
public:
// Defines possible modes for adding a packet to a graph input stream.
@ -157,8 +160,9 @@ class CalculatorGraph {
std::function<::mediapipe::Status(const Packet&)> packet_callback);
// Adds an OutputStreamPoller for a stream. This provides a synchronous,
// polling API for accessing a stream's output. For asynchronous output, use
// ObserveOutputStream. See also the helpers in tool/sink.h.
// polling API for accessing a stream's output. Should only be called before
// Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
// also the helpers in tool/sink.h.
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name);
// Gets output side packet by name after the graph is done. However, base
@ -300,6 +304,13 @@ class CalculatorGraph {
void RecordError(const ::mediapipe::Status& error)
ABSL_LOCKS_EXCLUDED(error_mutex_);
// Combines errors into a status. Returns true if the vector of errors is
// non-empty.
bool GetCombinedErrors(const std::string& error_prefix,
::mediapipe::Status* error_status);
// Convenience overload which specifies a default error prefix.
bool GetCombinedErrors(::mediapipe::Status* error_status);
// Returns the maximum input stream queue size.
int GetMaxInputStreamQueueSize();
@ -501,13 +512,6 @@ class CalculatorGraph {
void CleanupAfterRun(::mediapipe::Status* status)
ABSL_LOCKS_EXCLUDED(error_mutex_);
// Combines errors into a status. Returns true if the vector of errors is
// non-empty.
bool GetCombinedErrors(const std::string& error_prefix,
::mediapipe::Status* error_status);
// Convenience overload which specifies a default error prefix.
bool GetCombinedErrors(::mediapipe::Status* error_status);
// Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one
// is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN).
// current_run_side_packets_ must be set before this function is called.

View File

@ -458,8 +458,9 @@ class Vector3
// return the index of the largest component (fabs)
int LargestAbsComponent() const {
Vector3 temp = Abs();
return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2
: temp[1] > temp[2] ? 1 : 2;
return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2
: temp[1] > temp[2] ? 1
: 2;
}
// return the index of the smallest, median ,largest component of the vector

View File

@ -155,7 +155,7 @@ class InputStreamHandler {
// max number of invocations that are allowed to be scheduled is reached.
// Returns true if at least one invocation has been scheduled.
// The latest minimum timestamp bound of the input streams is returned in
// *input_bound iff the latest readiness of the node is kNotReady when the
// *input_bound if the latest readiness of the node is kNotReady when the
// function returns. During batching, this value will be equal to the
// timestamp of the first set of inputs in the batch. In other cases,
// Timestamp::Unset() is returned.

View File

@ -61,11 +61,25 @@ class LegacyCalculatorSupport {
// platforms.
#ifndef __APPLE__
ABSL_CONST_INIT
#endif // !__APPLE__
#endif // !__APPLE__
static thread_local C* current_; // NOLINT
};
};
// We only declare this variable for two specializations of the template because
// it is only meant to be used for these two types.
// Note that, since these variables are members of specific template
// _specializations_, they are not themselves templates, and therefore their
// definitions must be in the .cc file. However, a declaration still needs to be
// included in the header, or some compilers will assume they have no
// definition.
template <>
thread_local CalculatorContext*
LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
template <>
thread_local CalculatorContract*
LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_

View File

@ -51,6 +51,18 @@ const HolderBase* GetHolder(const Packet& packet) {
return packet.holder_.get();
}
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
const std::string& type_name, const std::string& serialized) {
ASSIGN_OR_RETURN(
auto message_holder,
packet_internal::MessageHolderRegistry::CreateByName(type_name));
auto* message =
const_cast<proto_ns::MessageLite*>(message_holder->GetProtoMessageLite());
RET_CHECK_NE(message, nullptr);
RET_CHECK(message->ParseFromString(serialized));
return packet_internal::Create(message_holder.release());
}
} // namespace packet_internal
Packet Packet::At(class Timestamp timestamp) const& {

View File

@ -27,6 +27,8 @@
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/deps/no_destructor.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
@ -51,6 +53,8 @@ Packet Create(HolderBase* holder, Timestamp timestamp);
Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp);
const HolderBase* GetHolder(const Packet& packet);
const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet);
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
const std::string& type_name, const std::string& serialized);
} // namespace packet_internal
// A generic container class which can hold data of any type. The type of
@ -355,112 +359,11 @@ class HolderBase {
// Downcasts this to Holder<T>. Returns nullptr if deserialization
// failed or if the requested type is not what is stored.
template <typename T>
inline Holder<T>* As(
typename std::enable_if<
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
!std::is_base_of<proto_ns::Message, T>::value) ||
(std::is_same<proto_ns::MessageLite, T>::value ||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
// For proto Message/MessageLite subclasses.
// When holder data is a concrete proto, the method downcasts this to
// Holder<T> if the requested type is what is stored.
// When holder data is a generic proto Message/MessageLite and a concrete
// proto type T is requested, the method will downcast the HolderBase to
// Holder<T> if the proto data is an instance of T.
template <typename T>
inline Holder<T>* As(
typename std::enable_if<
(std::is_base_of<proto_ns::MessageLite, T>::value ||
std::is_base_of<proto_ns::Message, T>::value) &&
(!std::is_same<proto_ns::MessageLite, T>::value &&
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
// Holder data is an instance of subclass type T.
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Holder data is a generic proto Message/MessageLite and a subclass type T
// is requested.
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
// legally downcast to Holder<T>, even though that downcast works in
// practice. Need to propose a better way to do the downcast.
Holder<T>* holder = static_cast<Holder<T>*>(this);
T tmp;
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
<< " vs requested proto type: " << tmp.GetTypeName();
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
return holder;
}
}
// Does not hold a T.
return nullptr;
}
Holder<T>* As();
// Same as non-const As() function.
template <typename T>
inline const Holder<T>* As(
typename std::enable_if<
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
!std::is_base_of<proto_ns::Message, T>::value) ||
(std::is_same<proto_ns::MessageLite, T>::value ||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
// For proto Message/MessageLite subclasses.
// When holder data is a concrete proto, the method downcasts this to
// Holder<T> if the requested type is what is stored.
// When holder data is a generic proto Message/MessageLite and a concrete
// proto type T is requested, the method will downcast the HolderBase to
// Holder<T> if the proto data is an instance of T.
template <typename T>
inline const Holder<T>* As(
typename std::enable_if<
(std::is_base_of<proto_ns::MessageLite, T>::value ||
std::is_base_of<proto_ns::Message, T>::value) &&
(!std::is_same<proto_ns::MessageLite, T>::value &&
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Holder data is a generic proto Message/MessageLite and a subclass type T
// is requested.
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
// legally downcast to Holder<T>, even though that downcast works in
// practice. Need to propose a better way to do the downcast.
Holder<T>* holder = static_cast<const Holder<T>*>(this);
T tmp;
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
<< " vs requested proto type: " << tmp.GetTypeName();
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
return holder;
}
}
// Does not hold a T.
return nullptr;
}
const Holder<T>* As() const;
// Returns the pointer to MessageLite type for the data in holder, if
// underlying object is protocol buffer type, otherwise, nullptr is returned.
@ -520,12 +423,68 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data,
return result;
}
// This registry is used to create Holders of the right concrete C++ type given
// a proto type std::string (which is used as the registration key).
class MessageHolderRegistry
: public GlobalFactoryRegistry<std::unique_ptr<HolderBase>> {};
template <typename T>
struct is_concrete_proto_t
: public std::integral_constant<
bool, std::is_base_of<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::Message, T>{}> {};
// Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T>
struct MessageRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;
};
// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), [] { return absl::make_unique<Holder<T>>(new T); }));
// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>
struct HolderSupport {
static void EnsureStaticInit() {}
};
// This template ensures that, for each concrete MessageLite subclass that is
// stored in a Packet, we register a function that allows us to create a
// Holder with the correct payload type from the proto's type name.
template <typename T>
struct HolderSupport<T,
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
// We must use std::remove_cv to ensure we don't try to register Foo twice if
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
// up to Holder?
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
// For the registration static member to be instantiated, it needs to be
// referenced in a context that requires the definition to exist (see ISO/IEC
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
// We need two different call-sites to cover proto types for which packets
// are only ever created (i.e. the protos are only produced by calculators)
// and proto types for which packets are only ever consumed (i.e. the protos
// are only consumed by calculators).
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
};
template <typename T>
class Holder : public HolderBase {
public:
explicit Holder(const T* ptr) : ptr_(ptr) { SetHolderTypeId<Holder>(); }
explicit Holder(const T* ptr) : ptr_(ptr) {
HolderSupport<T>::EnsureStaticInit();
SetHolderTypeId<Holder>();
}
~Holder() override { delete_helper(); }
const T& data() const { return *ptr_; }
const T& data() const {
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
size_t GetTypeId() const final { return tool::GetTypeHash<T>(); }
// Releases the underlying data pointer and transfers the ownership to a
// unique pointer.
@ -622,6 +581,24 @@ class ForeignHolder : public Holder<T> {
}
};
template <typename T>
Holder<T>* HolderBase::As() {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
template <typename T>
const Holder<T>* HolderBase::As() const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
} // namespace packet_internal
inline Packet::Packet(const Packet& packet)

View File

@ -0,0 +1,57 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
namespace test_ns {
class TestSinkCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
cc->Outputs().Tag("OUT").Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
cc->Outputs().Tag("OUT").AddPacket(
MakePacket<int>(x).At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(::mediapipe::test_ns::TestSinkCalculator);
} // namespace test_ns
TEST(PacketTest, InputTypeRegistration) {
using testing::Contains;
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
"mediapipe.InputOnlyProto");
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
Contains("mediapipe.InputOnlyProto"));
}
} // namespace
} // namespace mediapipe

View File

@ -174,54 +174,13 @@ TEST(PacketTest, ReturnGenericProtobufMessage) {
.x(0));
}
TEST(PacketTest, ReturnProtobufMessageSubType) {
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123);
Packet packet = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, TryWrongProtobufMessageSubType) {
// Packet of PacketTestProto.
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123);
Packet packet = Adopt(proto_ptr.release());
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
// Packet of proto_ns::Message.
proto_ptr.reset(new ::mediapipe::PacketTestProto);
proto_ptr->add_x(456);
Packet packet2 = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
EXPECT_FALSE(packet2.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet2.ValidateAsType<::mediapipe::PacketTestProto>().ok());
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, ReturnProtobufMessageLiteSubType) {
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123);
Packet packet =
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, TryWrongProtobufMessageLiteSubType) {
// Packet of PacketTestProto.
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
// Packet of proto_ns::MessageLite.
proto_ptr->add_x(456);
Packet packet =
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
EXPECT_EQ(456, packet.Get<::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, GetProtoBase) {
@ -505,5 +464,26 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) {
EXPECT_TRUE(packet2.IsEmpty());
}
TEST(PacketTest, MessageHolderRegistration) {
using testing::Contains;
Packet packet = MakePacket<mediapipe::SimpleProto>();
ASSERT_EQ(mediapipe::SimpleProto{}.GetTypeName(), "mediapipe.SimpleProto");
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
Contains("mediapipe.SimpleProto"));
}
TEST(PacketTest, PacketFromSerializedProto) {
mediapipe::SimpleProto original;
original.add_value("foo");
std::string serialized = original.SerializeAsString();
StatusOr<Packet> maybe_packet = packet_internal::PacketFromDynamicProto(
"mediapipe.SimpleProto", serialized);
MP_ASSERT_OK(maybe_packet);
Packet packet = maybe_packet.ValueOrDie();
MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>());
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
}
} // namespace
} // namespace mediapipe

View File

@ -39,3 +39,9 @@ message SerializationProxyProto {
repeated float float_value = 2;
repeated string string_value = 3;
}
// This proto should be used only as an input to a calculator, to verify that
// that case is covered.
message InputOnlyProto {
optional int32 x = 1;
}

View File

@ -46,7 +46,7 @@
// but may or may not still be able to run other OpenGL code.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
(defined(__APPLE__) || defined(__EMSCRIPTEN__) || \
defined(MEDIAPIPE_DISABLE_GPU))
defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER)
#define MEDIAPIPE_DISABLE_GL_COMPUTE
#endif

View File

@ -143,29 +143,29 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
{{MakePacket<std::string>("goodbye").At(start_timestamp_)}});
// Validate the GraphTrace data.
EXPECT_THAT(GetTrace(),
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000
base_timestamp: 1608911100000000
stream_name: ""
stream_name: "input_stream"
stream_name: "output_stream"
calculator_trace {
node_id: 0
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 10000
thread_id: 0
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 1
}
output_trace { packet_timestamp: 0 stream_id: 2 }
}
)")));
EXPECT_THAT(
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000
base_timestamp: 1608911100000000
stream_name: ""
stream_name: "input_stream"
stream_name: "output_stream"
calculator_trace {
node_id: 0
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 10000
thread_id: 0
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 1
}
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
}
)")));
}
TEST_F(GraphTracerTest, GraphTrace) {
@ -205,92 +205,101 @@ TEST_F(GraphTracerTest, GraphTrace) {
LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
{{MakePacket<std::string>("out").At(start_timestamp_)}});
curr_time += absl::Microseconds(2000);
ClearCalculatorContext("PCalculator_3");
LogInputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
// Note: the packet data ID is based on the packet's payload address, which
// means the same ID can be reused if data is allocated in the same location
// as a previously expired packet (b/160212191). This means the generated
// trace can change depending on the allocator. To keep results stable, we
// must keep the packets used in this test alive until the end. Each
// TestContextBuilder happens to keep a reference to all packets for the last
// context, so for now we just create a separate TestContextBuilder instead of
// clearing it. TODO: revise this test.
SetUpCalculatorContext("PCalculator_3a", /*node_id=*/2, {"up_2"}, {"down_2"});
LogInputPackets("PCalculator_3a", GraphTrace::PROCESS, curr_time,
{MakePacket<std::string>("pup").At(start_timestamp_ + 5)});
curr_time += absl::Microseconds(20000);
LogOutputPackets(
"PCalculator_3", GraphTrace::PROCESS, curr_time,
"PCalculator_3a", GraphTrace::PROCESS, curr_time,
{{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}});
curr_time += absl::Microseconds(1000);
// Validate the GraphTrace data.
EXPECT_THAT(GetTrace(),
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000
base_timestamp: 1608911100000000
stream_name: ""
stream_name: "input_stream"
stream_name: "up_1"
stream_name: "up_2"
stream_name: "down_1"
stream_name: "down_2"
calculator_trace {
node_id: 0
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 10000
thread_id: 0
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 1
}
output_trace { packet_timestamp: 0 stream_id: 2 }
output_trace { packet_timestamp: 0 stream_id: 3 }
output_trace { packet_timestamp: 5 stream_id: 3 }
}
calculator_trace {
node_id: 1
input_timestamp: 0
event_type: PROCESS
start_time: 11000
finish_time: 21000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 11000
packet_timestamp: 0
stream_id: 2
event_data: 2
}
output_trace { packet_timestamp: 0 stream_id: 4 }
}
calculator_trace {
node_id: 2
input_timestamp: 0
event_type: PROCESS
start_time: 16000
finish_time: 36000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 16000
packet_timestamp: 0
stream_id: 3
event_data: 3
}
output_trace { packet_timestamp: 0 stream_id: 5 }
}
calculator_trace {
node_id: 2
input_timestamp: 5
event_type: PROCESS
start_time: 38000
finish_time: 58000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 38000
packet_timestamp: 5
stream_id: 3
event_data: 4
}
output_trace { packet_timestamp: 5 stream_id: 5 }
}
)")));
EXPECT_THAT(
GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000
base_timestamp: 1608911100000000
stream_name: ""
stream_name: "input_stream"
stream_name: "up_1"
stream_name: "up_2"
stream_name: "down_1"
stream_name: "down_2"
calculator_trace {
node_id: 0
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 10000
thread_id: 0
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 1
}
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
output_trace { packet_timestamp: 0 stream_id: 3 event_data: 3 }
output_trace { packet_timestamp: 5 stream_id: 3 event_data: 4 }
}
calculator_trace {
node_id: 1
input_timestamp: 0
event_type: PROCESS
start_time: 11000
finish_time: 21000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 11000
packet_timestamp: 0
stream_id: 2
event_data: 5
}
output_trace { packet_timestamp: 0 stream_id: 4 event_data: 6 }
}
calculator_trace {
node_id: 2
input_timestamp: 0
event_type: PROCESS
start_time: 16000
finish_time: 36000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 16000
packet_timestamp: 0
stream_id: 3
event_data: 7
}
output_trace { packet_timestamp: 0 stream_id: 5 event_data: 8 }
}
calculator_trace {
node_id: 2
input_timestamp: 5
event_type: PROCESS
start_time: 38000
finish_time: 58000
thread_id: 0
input_trace {
start_time: 10000
finish_time: 38000
packet_timestamp: 5
stream_id: 3
event_data: 9
}
output_trace { packet_timestamp: 5 stream_id: 5 event_data: 10 }
}
)")));
// No timestamps are completed before start_time_.
// One timestamp is completed before start_time_ + 10ms.
@ -1275,37 +1284,39 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
GraphTrace trace_1;
builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(),
&trace_1);
EXPECT_THAT(trace_1, EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
R"(
base_time: 1100
base_timestamp: 1000
stream_name: ""
stream_name: "stream_1"
stream_name: "stream_2"
calculator_trace {
node_id: 333
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 1000
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 0
}
output_trace { packet_timestamp: 0 stream_id: 2 }
thread_id: 0
}
calculator_trace {
node_id: 333
input_timestamp: 0
event_type: GPU_TASK
start_time: 100
finish_time: 2100
thread_id: 0
}
)")));
EXPECT_THAT(
trace_1,
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
R"(
base_time: 1100
base_timestamp: 1000
stream_name: ""
stream_name: "stream_1"
stream_name: "stream_2"
calculator_trace {
node_id: 333
input_timestamp: 0
event_type: PROCESS
start_time: 0
finish_time: 1000
input_trace {
finish_time: 0
packet_timestamp: 0
stream_id: 1
event_data: 0
}
output_trace { packet_timestamp: 0 stream_id: 2 event_data: 0 }
thread_id: 0
}
calculator_trace {
node_id: 333
input_timestamp: 0
event_type: GPU_TASK
start_time: 100
finish_time: 2100
thread_id: 0
}
)")));
GraphTrace trace_2;
builder.CreateLog(buffer, absl::InfinitePast(), absl::InfiniteFuture(),

View File

@ -330,13 +330,12 @@ class TraceBuilder::Impl {
if (trace_event_registry_[event->event_type].is_stream_event()) {
auto stream_trace = event->is_finish ? result->add_output_trace()
: result->add_input_trace();
if (event->is_finish) {
// Log only the packet id for each output event.
stream_trace->set_stream_id(stream_id_map_[event->stream_id]);
stream_trace->set_packet_timestamp(LogTimestamp(event->packet_ts));
} else {
// Log the full stream trace for each input event.
BuildStreamTrace(*event, stream_trace);
BuildStreamTrace(*event, stream_trace);
if (!event->is_finish) {
// Note: is_finish is true for output events, false for input events.
// For input events, we log some additional timing information. The
// finish_time is the start_time of this Process call, the start_time
// is the finish_time of the Process call that output the packet.
stream_trace->set_finish_time(LogTime(event->event_time));
const TraceEvent* output_event = FindOutputEvent(*event);
if (output_event) {

View File

@ -116,10 +116,19 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
CHECK_EQ(stream_ts, Timestamp::Done());
if (ProcessTimestampBounds()) {
// With kReadyForClose, the timestamp-bound Done is returned.
// This bound is processed using the preceding input-timestamp.
// TODO: Make all InputStreamHandlers process Done() like this.
ready_timestamps_[i] = stream_ts.PreviousAllowedInStream();
input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
static const Timestamp kDonePrecedingTimestamp =
Timestamp::Done().PreviousAllowedInStream();
if (prev_ts < kDonePrecedingTimestamp) {
// When kReadyForClose is received for the first time for a sync set,
// it is processed using the timestamp preceding Done() to indicate
// input stream is done, but still needs to be processed.
min_bound = std::min(min_bound, kDonePrecedingTimestamp);
input_timestamp = std::min(input_timestamp, kDonePrecedingTimestamp);
ready_timestamps_[i] = kDonePrecedingTimestamp;
} else {
ready_timestamps_[i] = Timestamp::Done();
}
} else if (prev_ts < Timestamp::Done()) {
stream_became_done = true;
ready_timestamps_[i] = Timestamp::Done();

View File

@ -133,6 +133,11 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
}
}
const InputStream& Input(const CollectionItemId& id) {
CHECK(cc_);
return cc_->Inputs().Get(id);
}
PacketType packet_type_;
std::function<void()> headers_ready_callback_;
std::function<void()> notification_callback_;
@ -262,6 +267,344 @@ TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) {
EXPECT_TRUE(errors_.empty());
}
TEST_F(ImmediateInputStreamHandlerTest, ProcessTimestampBounds) {
input_stream_handler_->SetProcessTimestampBounds(true);
Timestamp min_stream_timestamp;
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::PreStream());
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
std::list<Packet> packets;
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
input_stream_handler_->AddPackets(input_b_id, packets);
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(
input_stream_handler_->GetInputStreamManager(input_b_id)->IsEmpty());
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
}
TEST_F(ImmediateInputStreamHandlerTest,
ProcessTimestampBoundsNoOpScheduleInvocations) {
input_stream_handler_->SetProcessTimestampBounds(true);
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
Timestamp min_stream_timestamp;
std::list<Packet> packets;
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
input_stream_handler_->AddPackets(input_b_id, packets);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Try to schedule invocations several times again. Considering nothing
// changed since last invocation nothing should be scheduled.
for (int i = 0; i < 3; ++i) {
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp(2));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Try to schedule invocations several times again. Considering nothing
// changed since last invocation nothing should be scheduled.
for (int i = 0; i < 3; ++i) {
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
}
// Due to some temporary changes in ImmediateInputStreamHandler some packets
// - were queued but never released
// - were released in incorrect order
// As other test cases were passing, this test case is designed to ensure that.
TEST_F(ImmediateInputStreamHandlerTest, VerifyPacketsReleaseOrder) {
input_stream_handler_->SetProcessTimestampBounds(true);
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
Packet packet_a = Adopt(new std::string("packet a"));
Packet packet_b = Adopt(new std::string("packet b"));
Packet packet_c = Adopt(new std::string("packet c"));
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(1))});
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(2))});
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(3))});
Timestamp min_stream_timestamp;
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ASSERT_FALSE(Input(input_a_id).IsEmpty());
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(5))});
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(5))});
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(5))});
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(2));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
ASSERT_FALSE(Input(input_b_id).IsEmpty());
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(2));
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(3));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(4));
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(3));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(5));
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
// This test simulates how CalculatorNode::ProcessNode() uses an input
// stream handler and the associated input streams.
TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) {

View File

@ -641,4 +641,61 @@ class DummyTestCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(DummyTestCalculator);
// A Calculator that passes the input value to the output after sleeping for
// a set number of microseconds.
class PassThroughWithSleepCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
if (sleep_micros_ < 0) {
return ::mediapipe::InternalError("SLEEP_MICROS should be >= 0");
}
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
clock_->Sleep(absl::Microseconds(sleep_micros_));
int value = cc->Inputs().Index(0).Value().Get<int>();
cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
private:
int sleep_micros_ = 0;
std::shared_ptr<Clock> clock_;
};
REGISTER_CALCULATOR(PassThroughWithSleepCalculator);
// A Calculator that multiples two input values.
class MultiplyIntCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
RET_CHECK(cc->Outputs().HasTag("OUT"));
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int x = cc->Inputs().Index(0).Value().Get<int>();
int y = cc->Inputs().Index(1).Value().Get<int>();
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(MultiplyIntCalculator);
} // namespace mediapipe

View File

@ -101,6 +101,13 @@ std::string ParseNameFromStream(const std::string& stream) {
return name;
}
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index) {
std::string tag;
int index;
MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
return {tag, index};
}
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) {
std::string tag, name;
int index;

View File

@ -76,6 +76,9 @@ std::string CanonicalNodeName(const CalculatorGraphConfig& graph_config,
// Parses the name from a "tag:index:name".
std::string ParseNameFromStream(const std::string& stream);
// Parses the TagIndex from a "tag:index".
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index);
// Parses the TagIndex from a "tag:index:name".
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream);

View File

@ -13,15 +13,15 @@
# limitations under the License.
#
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe:__subpackages__"])
load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph",
)
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "test_graph",
srcs = ["test.pbtxt"],
@ -31,6 +31,8 @@ exports_files([
"test.pbtxt",
"dub_quad_test_subgraph.pbtxt",
"nested_test_subgraph.pbtxt",
"single_flow_container_test.pbtxt",
"dual_flow_container_test.pbtxt",
])
mediapipe_simple_subgraph(

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load("//mediapipe/gpu:metal.bzl", "metal_library")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
# Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can
# interfere with desktop GL. b/73494271
config_setting(

View File

@ -39,6 +39,7 @@ namespace mediapipe {
// ROTATION: the counterclockwise rotation angle in degrees. This allows
// user to specify different rotation angles for different frames. If this
// stream is provided, it will override the ROTATION input side packet.
// OUTPUT_DIMENSIONS: the output width and height in pixels.
// Additional output streams:
// TOP_BOTTOM_PADDING: If use FIT scale mode, this stream outputs the padding
// size of the input image in normalized value [0, 1] for top and bottom
@ -103,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator);
if (cc->Inputs().HasTag("ROTATION")) {
cc->Inputs().Tag("ROTATION").Set<int>();
}
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
}
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
if (cc->InputSidePackets().HasTag("OPTIONS")) {
@ -181,6 +185,18 @@ REGISTER_CALCULATOR(GlScalerCalculator);
}
::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
// OUTPUT_DIMENSIONS input stream is specified, but value is missing.
return ::mediapipe::OkStatus();
}
const auto& dimensions =
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<DimensionsPacketType>();
dst_width_ = dimensions[0];
dst_height_ = dimensions[1];
}
return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>();
QuadRenderer* renderer = nullptr;
@ -199,7 +215,7 @@ REGISTER_CALCULATOR(GlScalerCalculator);
src1 = helper_.CreateSourceTexture(input, 0);
src2 = helper_.CreateSourceTexture(input, 1);
} else // NOLINT(readability/braces)
#endif // __APPLE__
#endif // __APPLE__
{
src1 = helper_.CreateSourceTexture(input);
#ifdef __ANDROID__
@ -211,7 +227,7 @@ REGISTER_CALCULATOR(GlScalerCalculator);
}
renderer = ext_rgb_renderer_.get();
} else // NOLINT(readability/braces)
#endif // __ANDROID__
#endif // __ANDROID__
{
if (!rgb_renderer_) {
rgb_renderer_ = absl::make_unique<QuadRenderer>();

View File

@ -140,6 +140,9 @@ node {
num_landmarks: 21
input_image_width: 256
input_image_height: 256
# The additional scaling factor is used to account for the Z coordinate
# distribution in the training data.
normalize_z: 0.4
}
}
}

View File

@ -144,6 +144,9 @@ node {
num_landmarks: 21
input_image_width: 256
input_image_height: 256
# The additional scaling factor is used to account for the Z coordinate
# distribution in the training data.
normalize_z: 0.4
}
}
}

View File

@ -25,6 +25,7 @@ android_library(
),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/glutil",
"//third_party:androidx_appcompat",

View File

@ -14,17 +14,21 @@
package com.google.mediapipe.components;
import static java.lang.Math.max;
import android.graphics.SurfaceTexture;
import android.opengl.GLES11Ext;
import android.opengl.GLES20;
import android.util.Log;
import com.google.mediapipe.framework.AppTextureFrame;
import com.google.mediapipe.framework.GlSyncToken;
import com.google.mediapipe.glutil.ExternalTextureRenderer;
import com.google.mediapipe.glutil.GlThread;
import com.google.mediapipe.glutil.ShaderUtil;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import javax.microedition.khronos.egl.EGLContext;
/**
@ -204,8 +208,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
private static final long NANOS_PER_MICRO = 1000; // Nanoseconds in one microsecond.
private volatile SurfaceTexture surfaceTexture = null;
private final List<TextureFrameConsumer> consumers;
private List<AppTextureFrame> outputFrames = null;
private int outputFrameIndex = -1;
private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>();
private int framesInUse = 0;
private final int framesToKeep;
private ExternalTextureRenderer renderer = null;
private long nextFrameTimestampOffset = 0;
private long timestampOffsetNanos = 0;
@ -215,10 +222,27 @@ public class ExternalTextureConverter implements TextureFrameProducer {
protected int destinationWidth = 0;
protected int destinationHeight = 0;
private class PoolTextureFrame extends AppTextureFrame {
public PoolTextureFrame(int textureName, int width, int height) {
super(textureName, width, height);
}
@Override
public void release(GlSyncToken syncToken) {
super.release(syncToken);
poolFrameReleased(this);
}
@Override
public void release() {
super.release();
poolFrameReleased(this);
}
}
public RenderThread(EGLContext parentContext, int numBuffers) {
super(parentContext);
outputFrames = new ArrayList<>();
outputFrames.addAll(Collections.nCopies(numBuffers, null));
framesToKeep = numBuffers;
renderer = new ExternalTextureRenderer();
consumers = new ArrayList<>();
}
@ -283,8 +307,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
@Override
public void releaseGl() {
setSurfaceTexture(null, 0, 0);
for (int i = 0; i < outputFrames.size(); ++i) {
teardownDestination(i);
while (!framesAvailable.isEmpty()) {
teardownFrame(framesAvailable.remove());
}
renderer.release();
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
@ -337,16 +361,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
}
}
private void teardownDestination(int index) {
if (outputFrames.get(index) != null) {
waitUntilReleased(outputFrames.get(index));
GLES20.glDeleteTextures(1, new int[] {outputFrames.get(index).getTextureName()}, 0);
outputFrames.set(index, null);
}
private static void teardownFrame(AppTextureFrame frame) {
GLES20.glDeleteTextures(1, new int[] {frame.getTextureName()}, 0);
}
private void setupDestination(int index) {
teardownDestination(index);
private PoolTextureFrame createFrame() {
int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight);
Log.d(
TAG,
@ -354,11 +373,9 @@ public class ExternalTextureConverter implements TextureFrameProducer {
"Created output texture: %d width: %d height: %d",
destinationTextureId, destinationWidth, destinationHeight));
bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight);
outputFrames.set(
index, new AppTextureFrame(destinationTextureId, destinationWidth, destinationHeight));
return new PoolTextureFrame(destinationTextureId, destinationWidth, destinationHeight);
}
/**
* Gets next available frame or creates new one if next frame is not initialized
* or cannot be used with current surface texture.
@ -371,20 +388,38 @@ public class ExternalTextureConverter implements TextureFrameProducer {
* NOTE: must be invoked on GL thread
*/
private AppTextureFrame nextOutputFrame() {
outputFrameIndex = (outputFrameIndex + 1) % outputFrames.size();
AppTextureFrame outputFrame = outputFrames.get(outputFrameIndex);
// Check if the size has changed.
if (outputFrame == null
|| outputFrame.getWidth() != destinationWidth
|| outputFrame.getHeight() != destinationHeight) {
// setupDestination will wait for the frame to be released before reallocating it.
setupDestination(outputFrameIndex);
outputFrame = outputFrames.get(outputFrameIndex);
PoolTextureFrame outputFrame;
synchronized (this) {
outputFrame = framesAvailable.poll();
framesInUse++;
}
if (outputFrame == null) {
outputFrame = createFrame();
} else if (outputFrame.getWidth() != destinationWidth
|| outputFrame.getHeight() != destinationHeight) {
// Create anew if size has changed.
// TODO: waiting for the consumer sync here may not be necessary.
waitUntilReleased(outputFrame);
teardownFrame(outputFrame);
outputFrame = createFrame();
} else {
// Note: waitUntilReleased does two things: waits for the frame to be released by the CPU,
// and syncs with the GPU sync token provided by the consumer. The first part is redundant
// here (and completes immediately), but the second part is still needed.
waitUntilReleased(outputFrame);
}
waitUntilReleased(outputFrame);
return outputFrame;
}
protected synchronized void poolFrameReleased(PoolTextureFrame frame) {
framesAvailable.offer(frame);
framesInUse--;
int keep = max(framesToKeep - framesInUse, 0);
while (framesAvailable.size() > keep) {
teardownFrame(framesAvailable.remove());
}
}
/**
* Updates output frame with current pixels of surface texture and corresponding timestamp.
*
@ -417,16 +452,22 @@ public class ExternalTextureConverter implements TextureFrameProducer {
Log.v(
TAG,
String.format(
"Waiting for tex: %d width: %d height: %d",
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
"Waiting for tex: %d width: %d height: %d timestamp: %d",
frame.getTextureName(),
frame.getWidth(),
frame.getHeight(),
frame.getTimestamp()));
}
frame.waitUntilReleased();
if (Log.isLoggable(TAG, Log.VERBOSE)) {
Log.v(
TAG,
String.format(
"Finished waiting for tex: %d width: %d height: %d",
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
"Finished waiting for tex: %d width: %d height: %d timestamp: %d",
frame.getTextureName(),
frame.getWidth(),
frame.getHeight(),
frame.getTimestamp()));
}
} catch (InterruptedException ie) {
// Someone interrupted our thread. This is not supposed to happen: we own

View File

@ -20,6 +20,7 @@ import android.media.AudioFormat;
import android.os.Handler;
import android.util.Log;
import com.google.common.base.Preconditions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
@ -32,10 +33,12 @@ import com.google.mediapipe.framework.SurfaceOutput;
import com.google.mediapipe.framework.TextureFrame;
import java.io.File;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
@ -106,6 +109,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
initializeGraphAndPacketCreator(context, graphName);
}
/**
* Constructor.
*
* @param graphConfig the proto object representation of the graph.
*/
public FrameProcessor(CalculatorGraphConfig graphConfig) {
initializeGraphAndPacketCreator(graphConfig);
}
/**
* Initializes a graph for processing data in real time.
*
@ -123,6 +135,17 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
packetCreator = new AndroidPacketCreator(mediapipeGraph);
}
/**
* Initializes a graph for processing data in real time.
*
* @param graphConfig the proto object representation of the graph.
*/
private void initializeGraphAndPacketCreator(CalculatorGraphConfig graphConfig) {
mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(graphConfig);
packetCreator = new AndroidPacketCreator(mediapipeGraph);
}
/** Callback for errors occurring during processing in the graph. */
public interface ErrorListener {
void onError(RuntimeException error);
@ -186,6 +209,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
currentConsumers = videoConsumers;
}
for (TextureFrameConsumer consumer : currentConsumers) {
// Note: each consumer will release its TextureFrame, so each gets a separate object
// (though they all reference the same data).
TextureFrame frame = PacketGetter.getTextureFrame(packet);
if (Log.isLoggable(TAG, Log.VERBOSE)) {
Log.v(
@ -373,9 +398,10 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
/**
* Returns true if the MediaPipe graph can accept one more input frame.
*
* @throws MediaPipeException for any error status.
*/
private boolean maybeAcceptNewFrame() {
private boolean maybeAcceptNewFrame(long timestamp) {
if (!started.getAndSet(true)) {
startGraph();
}
@ -395,7 +421,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
frame.getTextureName(), frame.getWidth(), frame.getHeight()));
}
if (!maybeAcceptNewFrame()) {
if (!maybeAcceptNewFrame(frame.getTimestamp())) {
return;
}
@ -451,7 +477,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
public void onNewFrame(final Bitmap bitmap, long timestamp) {
Packet packet = null;
try {
if (!maybeAcceptNewFrame()) {
if (!maybeAcceptNewFrame(timestamp)) {
return;
}

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.components;
import android.Manifest;
import android.app.Activity;
import android.content.pm.PackageManager;
import androidx.core.app.ActivityCompat;
import android.util.Log;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
/** Manages camera permission request and handling. */

View File

@ -18,6 +18,10 @@ import com.google.mediapipe.framework.TextureFrame;
/** Lightweight abstraction for an object that can receive video frames. */
public interface TextureFrameConsumer {
/** Called when a new {@link TextureFrame} is available. */
/**
* Called when a new {@link TextureFrame} is available.
*
* Important: implementations of this method should call frame.release().
**/
public abstract void onNewFrame(TextureFrame frame);
}

View File

@ -272,6 +272,10 @@ public final class PacketGetter {
* <p>Note: in order for the application to be able to use the texture, its GL context must be
* linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)}
* with the native handle to the application's GL context as the second argument.
*
* <p>The returned GraphTextureFrame must be released by the caller. If this method is called
* multiple times, each returned GraphTextureFrame is an independent reference to the underlying
* texture data, and must be released individually.
*/
public static GraphTextureFrame getTextureFrame(final Packet packet) {
return new GraphTextureFrame(

View File

@ -1,7 +0,0 @@
tricorder: {
options: {
builder: {
config: "android_arm"
}
}
}

View File

@ -33,22 +33,22 @@ struct GpuSharedData;
/// Provides the delegate with a new video frame.
@optional
- (void)mediapipeGraph:(MPPGraph*)graph
- (void)mediapipeGraph:(MPPGraph *)graph
didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer
fromStream:(const std::string&)streamName;
fromStream:(const std::string &)streamName;
/// Provides the delegate with a new video frame and time stamp.
@optional
- (void)mediapipeGraph:(MPPGraph*)graph
- (void)mediapipeGraph:(MPPGraph *)graph
didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer
fromStream:(const std::string&)streamName
timestamp:(const mediapipe::Timestamp&)timestamp;
fromStream:(const std::string &)streamName
timestamp:(const mediapipe::Timestamp &)timestamp;
/// Provides the delegate with a raw packet.
@optional
- (void)mediapipeGraph:(MPPGraph*)graph
didOutputPacket:(const mediapipe::Packet&)packet
fromStream:(const std::string&)streamName;
- (void)mediapipeGraph:(MPPGraph *)graph
didOutputPacket:(const mediapipe::Packet &)packet
fromStream:(const std::string &)streamName;
@end
@ -100,34 +100,34 @@ typedef NS_ENUM(int, MPPPacketType) {
/// Copies the config and initializes the graph.
/// @param config The configuration describing the graph.
- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config
- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig &)config
NS_DESIGNATED_INITIALIZER;
- (mediapipe::ProfilingContext*)getProfiler;
- (mediapipe::ProfilingContext *)getProfiler;
/// Sets a stream header. If the header was already set, it is overwritten.
/// @param packet The header.
/// @param streamName The name of the stream.
- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName;
- (void)setHeaderPacket:(const mediapipe::Packet &)packet forStream:(const std::string &)streamName;
/// Sets a side packet. If it was already set, it is overwritten.
/// Must be called before the graph is started.
/// @param packet The packet to be associated with the input side packet.
/// @param name The name of the input side packet.
- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name;
- (void)setSidePacket:(const mediapipe::Packet &)packet named:(const std::string &)name;
/// Sets a service packet. If it was already set, it is overwritten.
/// Must be called before the graph is started.
/// @param packet The packet to be associated with the service.
/// @param service.
- (void)setServicePacket:(mediapipe::Packet&)packet
forService:(const mediapipe::GraphServiceBase&)service;
- (void)setServicePacket:(mediapipe::Packet &)packet
forService:(const mediapipe::GraphServiceBase &)service;
/// Adds input side packets from a map. Any inputs that were already set are
/// left unchanged.
/// Must be called before the graph is started.
/// @param extraInputSidePackets The input side packets to be added.
- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet>&)extraSidePackets;
- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet> &)extraSidePackets;
// TODO: rename to addDelegateOutputStream:packetType:
/// Add an output stream in the graph from which the delegate wants to receive
@ -135,30 +135,30 @@ typedef NS_ENUM(int, MPPPacketType) {
/// @param outputStreamName The name of the output stream from which
/// the delegate will receive frames.
/// @param packetType The type of packet provided by the output streams.
- (void)addFrameOutputStream:(const std::string&)outputStreamName
- (void)addFrameOutputStream:(const std::string &)outputStreamName
outputPacketType:(MPPPacketType)packetType;
/// Starts running the graph.
/// @return YES if successful.
- (BOOL)startWithError:(NSError**)error;
- (BOOL)startWithError:(NSError **)error;
/// Sends a generic packet into a graph input stream.
/// The graph must have been started before calling this.
/// Returns YES if the packet was successfully sent.
- (BOOL)sendPacket:(const mediapipe::Packet&)packet
intoStream:(const std::string&)streamName
error:(NSError**)error;
- (BOOL)sendPacket:(const mediapipe::Packet &)packet
intoStream:(const std::string &)streamName
error:(NSError **)error;
- (BOOL)movePacket:(mediapipe::Packet&&)packet
intoStream:(const std::string&)streamName
error:(NSError**)error;
- (BOOL)movePacket:(mediapipe::Packet &&)packet
intoStream:(const std::string &)streamName
error:(NSError **)error;
/// Sets the maximum queue size for a stream. Experimental feature, currently
/// only supported for graph input streams. Should be called before starting the
/// graph.
- (BOOL)setMaxQueueSize:(int)maxQueueSize
forStream:(const std::string&)streamName
error:(NSError**)error;
forStream:(const std::string &)streamName
error:(NSError **)error;
/// Creates a MediaPipe packet wrapping the given pixelBuffer;
- (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)pixelBuffer
@ -170,9 +170,9 @@ typedef NS_ENUM(int, MPPPacketType) {
/// allows MediaPipe to overwrite the packet contents on successful sending for
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
intoStream:(const std::string&)inputName
intoStream:(const std::string &)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp
timestamp:(const mediapipe::Timestamp &)timestamp
allowOverwrite:(BOOL)allowOverwrite;
/// Sends a pixel buffer into a graph input stream, using the specified packet
@ -180,9 +180,23 @@ typedef NS_ENUM(int, MPPPacketType) {
/// returns NO if maxFramesInFlight is exceeded. Returns YES if the packet was
/// successfully sent.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer
intoStream:(const std::string&)inputName
intoStream:(const std::string &)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp;
timestamp:(const mediapipe::Timestamp &)timestamp;
/// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. Drops frames and
/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES,
/// allows MediaPipe to overwrite the packet contents on successful sending for
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
/// Sets error to a non-nil value if an error occurs in the graph when sending the
/// packet.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
intoStream:(const std::string &)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp &)timestamp
allowOverwrite:(BOOL)allowOverwrite
error:(NSError **)error;
/// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. The timestamp is
@ -190,32 +204,32 @@ typedef NS_ENUM(int, MPPPacketType) {
/// frames and returns NO if maxFramesInFlight is exceeded. Returns YES if the
/// packet was successfully sent.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer
intoStream:(const std::string&)inputName
intoStream:(const std::string &)inputName
packetType:(MPPPacketType)packetType;
/// Cancels a graph run. You must still call waitUntilDoneWithError: after this.
- (void)cancel;
/// Check if the graph contains this input stream
- (BOOL)hasInputStream:(const std::string&)inputName;
- (BOOL)hasInputStream:(const std::string &)inputName;
/// Closes an input stream.
/// You must close all graph input streams before stopping the graph.
/// @return YES if successful.
- (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error;
- (BOOL)closeInputStream:(const std::string &)inputName error:(NSError **)error;
/// Closes all graph input streams.
/// @return YES if successful.
- (BOOL)closeAllInputStreamsWithError:(NSError**)error;
- (BOOL)closeAllInputStreamsWithError:(NSError **)error;
/// Stops running the graph.
/// Call this before releasing this object. All input streams must have been
/// closed. This call does not time out, so you should not call it from the main
/// thread.
/// @return YES if successful.
- (BOOL)waitUntilDoneWithError:(NSError**)error;
- (BOOL)waitUntilDoneWithError:(NSError **)error;
/// Waits for the graph to become idle.
- (BOOL)waitUntilIdleWithError:(NSError**)error;
- (BOOL)waitUntilIdleWithError:(NSError **)error;
@end

View File

@ -327,22 +327,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp
allowOverwrite:(BOOL)allowOverwrite {
NSError* error;
bool success = [self sendPixelBuffer:imageBuffer
intoStream:inputName
packetType:packetType
timestamp:timestamp
allowOverwrite:allowOverwrite
error:&error];
if (error) {
_GTMDevLog(@"failed to send packet: %@", error);
}
return success;
}
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
intoStream:(const std::string&)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp
allowOverwrite:(BOOL)allowOverwrite
error:(NSError**)error {
if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType];
NSError* error;
BOOL success;
if (allowOverwrite) {
packet = std::move(packet).At(timestamp);
success = [self movePacket:std::move(packet)
intoStream:inputName
error:&error];
success = [self movePacket:std::move(packet) intoStream:inputName error:error];
} else {
success = [self sendPacket:packet.At(timestamp)
intoStream:inputName
error:&error];
success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
}
if (success) _framesInFlight++;
else _GTMDevLog(@"failed to send packet: %@", error);
return success;
}

View File

@ -423,6 +423,10 @@ tasks and tracking (or class) fields for tracking information.
|`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.|
|`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.|
|`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.|
|`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.|
|`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.|
|`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.|
|`region/3d_point/\*`| *special* |`add_bbox_3d_point` / `AddBBox3dPoint`|Operates on 3d_point/{x,y,z} with a single call.|
|`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.|
|`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.|
|`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. Distinguishes empty from unannotated frames.|

View File

@ -229,6 +229,18 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
sequence);
}
}
if (Get3dPointSize(prefix, *sequence) > 0) {
std::string x_key = merge_prefix(prefix, kRegion3dPointXKey);
auto* region_feature_list = MutableFeatureList(x_key, sequence);
RET_CHECK_EQ(num_bboxes, region_feature_list->feature_size())
<< "Expected number of BBox timestamps and boxes to match.";
ClearBBoxNumRegions(prefix, sequence);
for (int i = 0; i < num_bboxes; ++i) {
AddBBoxNumRegions(
prefix, region_feature_list->feature(i).float_list().value_size(),
sequence);
}
}
// Collect which timestamps currently match to which indices in timestamps.
// skip empty timestamps.
// Requires sorted indices.
@ -453,6 +465,47 @@ void ClearPoint(const std::string& prefix,
ClearBBoxPointX(prefix, sequence);
}
int Get3dPointSize(const std::string& prefix,
const tensorflow::SequenceExample& sequence) {
return GetBBox3dPointXSize(prefix, sequence);
}
std::vector<::std::tuple<float, float, float>> Get3dPointAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index) {
const auto& xs = GetBBox3dPointXAt(prefix, sequence, index);
const auto& ys = GetBBox3dPointYAt(prefix, sequence, index);
const auto& zs = GetBBox3dPointZAt(prefix, sequence, index);
std::vector<::std::tuple<float, float, float>> points(ys.size());
for (int i = 0; i < xs.size(); ++i) {
points[i] = std::make_tuple(xs[i], ys[i], zs[i]);
}
return points;
}
void Add3dPoint(const std::string& prefix,
const std::vector<::std::tuple<float, float, float>>& points,
tensorflow::SequenceExample* sequence) {
::std::vector<float> xs;
::std::vector<float> ys;
::std::vector<float> zs;
for (auto& point : points) {
xs.push_back(std::get<0>(point));
ys.push_back(std::get<1>(point));
zs.push_back(std::get<2>(point));
}
AddBBox3dPointX(prefix, xs, sequence);
AddBBox3dPointY(prefix, ys, sequence);
AddBBox3dPointZ(prefix, zs, sequence);
}
void Clear3dPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence) {
ClearBBox3dPointX(prefix, sequence);
ClearBBox3dPointY(prefix, sequence);
ClearBBox3dPointZ(prefix, sequence);
}
std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index) {

View File

@ -268,6 +268,10 @@ const char kRegionBBoxXMaxKey[] = "region/bbox/xmax";
const char kRegionPointXKey[] = "region/point/x";
const char kRegionPointYKey[] = "region/point/y";
const char kRegionRadiusKey[] = "region/radius";
// The 3d point can denote keypoints.
const char kRegion3dPointXKey[] = "region/3d_point/x";
const char kRegion3dPointYKey[] = "region/3d_point/y";
const char kRegion3dPointZKey[] = "region/3d_point/z";
// The number of regions at that timestep.
const char kRegionNumRegionsKey[] = "region/num_regions";
// Whether that timestep is annotated for bounding regions.
@ -333,61 +337,111 @@ void AddPoint(const std::string& prefix,
void ClearPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence);
#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \
inline int CONCAT_STR3(Get, identifier, \
Size)(const tensorflow::SequenceExample& sequence) { \
return GetBBoxSize(prefix, sequence); \
} \
inline std::vector<::mediapipe::Location> CONCAT_STR3(Get, identifier, At)( \
const tensorflow::SequenceExample& sequence, int index) { \
return GetBBoxAt(prefix, sequence, index); \
} \
inline void CONCAT_STR2(Add, identifier)( \
const std::vector<::mediapipe::Location>& bboxes, \
tensorflow::SequenceExample* sequence) { \
return AddBBox(prefix, bboxes, sequence); \
} \
inline void CONCAT_STR2( \
Clear, identifier)(tensorflow::SequenceExample * sequence) { \
return ClearBBox(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, PointSize)( \
const tensorflow::SequenceExample& sequence) { \
return GetPointSize(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, PointSize)( \
const std::string& name, const tensorflow::SequenceExample& sequence) { \
return GetPointSize(name, sequence); \
} \
inline std::vector<std::pair<float, float>> CONCAT_STR3( \
Get, identifier, PointAt)(const tensorflow::SequenceExample& sequence, \
int index) { \
return GetPointAt(prefix, sequence, index); \
} \
inline std::vector<std::pair<float, float>> CONCAT_STR3( \
Get, identifier, PointAt)(const std::string& name, \
const tensorflow::SequenceExample& sequence, \
int index) { \
return GetPointAt(name, sequence, index); \
} \
inline void CONCAT_STR3(Add, identifier, Point)( \
const std::vector<std::pair<float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return AddPoint(prefix, points, sequence); \
} \
inline void CONCAT_STR3(Add, identifier, Point)( \
const std::string& name, \
const std::vector<std::pair<float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return AddPoint(name, points, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, \
Point)(tensorflow::SequenceExample * sequence) { \
return ClearPoint(prefix, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, Point)( \
std::string name, tensorflow::SequenceExample * sequence) { \
return ClearPoint(name, sequence); \
// The input and output format is a pair of <y, x> coordinates to match the
// order of bounding box coordinates.
int Get3dPointSize(const std::string& prefix,
const tensorflow::SequenceExample& sequence);
std::vector<std::tuple<float, float, float>> Get3dPointAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index);
void Add3dPoint(const std::string& prefix,
const std::vector<std::tuple<float, float, float>>& points,
tensorflow::SequenceExample* sequence);
void Clear3dPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence);
#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \
inline int CONCAT_STR3(Get, identifier, \
Size)(const tensorflow::SequenceExample& sequence) { \
return GetBBoxSize(prefix, sequence); \
} \
inline std::vector<::mediapipe::Location> CONCAT_STR3(Get, identifier, At)( \
const tensorflow::SequenceExample& sequence, int index) { \
return GetBBoxAt(prefix, sequence, index); \
} \
inline void CONCAT_STR2(Add, identifier)( \
const std::vector<::mediapipe::Location>& bboxes, \
tensorflow::SequenceExample* sequence) { \
return AddBBox(prefix, bboxes, sequence); \
} \
inline void CONCAT_STR2( \
Clear, identifier)(tensorflow::SequenceExample * sequence) { \
return ClearBBox(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, PointSize)( \
const tensorflow::SequenceExample& sequence) { \
return GetPointSize(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, PointSize)( \
const std::string& name, const tensorflow::SequenceExample& sequence) { \
return GetPointSize(name, sequence); \
} \
inline std::vector<std::pair<float, float>> CONCAT_STR3( \
Get, identifier, PointAt)(const tensorflow::SequenceExample& sequence, \
int index) { \
return GetPointAt(prefix, sequence, index); \
} \
inline std::vector<std::pair<float, float>> CONCAT_STR3( \
Get, identifier, PointAt)(const std::string& name, \
const tensorflow::SequenceExample& sequence, \
int index) { \
return GetPointAt(name, sequence, index); \
} \
inline void CONCAT_STR3(Add, identifier, Point)( \
const std::vector<std::pair<float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return AddPoint(prefix, points, sequence); \
} \
inline void CONCAT_STR3(Add, identifier, Point)( \
const std::string& name, \
const std::vector<std::pair<float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return AddPoint(name, points, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, \
Point)(tensorflow::SequenceExample * sequence) { \
return ClearPoint(prefix, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, Point)( \
std::string name, tensorflow::SequenceExample * sequence) { \
return ClearPoint(name, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
const tensorflow::SequenceExample& sequence) { \
return Get3dPointSize(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
const std::string& name, const tensorflow::SequenceExample& sequence) { \
return Get3dPointSize(name, sequence); \
} \
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
Get, identifier, 3dPointAt)(const tensorflow::SequenceExample& sequence, \
int index) { \
return Get3dPointAt(prefix, sequence, index); \
} \
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
Get, identifier, 3dPointAt)(const std::string& name, \
const tensorflow::SequenceExample& sequence, \
int index) { \
return Get3dPointAt(name, sequence, index); \
} \
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
const std::vector<std::tuple<float, float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return Add3dPoint(prefix, points, sequence); \
} \
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
const std::string& name, \
const std::vector<std::tuple<float, float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return Add3dPoint(name, points, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, \
3dPoint)(tensorflow::SequenceExample * sequence) { \
return Clear3dPoint(prefix, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, 3dPoint)( \
std::string name, tensorflow::SequenceExample * sequence) { \
return Clear3dPoint(name, sequence); \
}
#define PREFIXED_BBOX(identifier, prefix) \
@ -435,6 +489,12 @@ void ClearPoint(const std::string& prefix,
kRegionPointYKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, Radius), \
kRegionRadiusKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointX), \
kRegion3dPointXKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointY), \
kRegion3dPointYKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointZ), \
kRegion3dPointZKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST( \
CONCAT_STR2(identifier, EmbeddingFloats), kRegionEmbeddingFloatKey, \
prefix) \

View File

@ -262,6 +262,10 @@ REGION_BBOX_XMAX_KEY = "region/bbox/xmax"
REGION_POINT_X_KEY = "region/point/x"
REGION_POINT_Y_KEY = "region/point/y"
REGION_RADIUS_KEY = "region/radius"
# The 3D point can denote keypoints.
REGION_3D_POINT_X_KEY = "region/3d_point/x"
REGION_3D_POINT_Y_KEY = "region/3d_point/y"
REGION_3D_POINT_Z_KEY = "region/3d_point/z"
# The number of regions at that timestep.
REGION_NUM_REGIONS_KEY = "region/num_regions"
# Whether that timestep is annotated for regions.
@ -365,6 +369,15 @@ def _create_region_with_prefix(name, prefix):
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(name + "_point_y", REGION_POINT_Y_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_x", REGION_3D_POINT_X_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_y", REGION_3D_POINT_Y_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_z", REGION_3D_POINT_Z_KEY,
prefix=prefix, module_dict=globals())
msu.create_bytes_list_context_feature(name + "_parts",
REGION_PARTS_KEY,
prefix=prefix, module_dict=globals())
@ -406,6 +419,39 @@ def _create_region_with_prefix(name, prefix):
clear_bbox_xmin(sequence_example, prefix=prefix)
clear_bbox_ymax(sequence_example, prefix=prefix)
clear_bbox_xmax(sequence_example, prefix=prefix)
def get_prefixed_point_at(index, sequence_example, prefix):
return np.stack((
get_bbox_point_y_at(index, sequence_example, prefix=prefix),
get_bbox_point_x_at(index, sequence_example, prefix=prefix)),
1)
def add_prefixed_point(values, sequence_example, prefix):
add_bbox_point_y(values[:, 0], sequence_example, prefix=prefix)
add_bbox_point_x(values[:, 1], sequence_example, prefix=prefix)
def get_prefixed_point_size(sequence_example, prefix):
return get_bbox_point_y_size(sequence_example, prefix=prefix)
def has_prefixed_point(sequence_example, prefix):
return has_bbox_point_y(sequence_example, prefix=prefix)
def clear_prefixed_point(sequence_example, prefix):
clear_bbox_point_y(sequence_example, prefix=prefix)
clear_bbox_point_x(sequence_example, prefix=prefix)
def get_prefixed_3d_point_at(index, sequence_example, prefix):
return np.stack((
get_bbox_3d_point_x_at(index, sequence_example, prefix=prefix),
get_bbox_3d_point_y_at(index, sequence_example, prefix=prefix),
get_bbox_3d_point_z_at(index, sequence_example, prefix=prefix)),
1)
def add_prefixed_3d_point(values, sequence_example, prefix):
add_bbox_3d_point_x(values[:, 0], sequence_example, prefix=prefix)
add_bbox_3d_point_y(values[:, 1], sequence_example, prefix=prefix)
add_bbox_3d_point_z(values[:, 2], sequence_example, prefix=prefix)
def get_prefixed_3d_point_size(sequence_example, prefix):
return get_bbox_3d_point_x_size(sequence_example, prefix=prefix)
def has_prefixed_3d_point(sequence_example, prefix):
return has_bbox_3d_point_x(sequence_example, prefix=prefix)
def clear_prefixed_3d_point(sequence_example, prefix):
clear_bbox_3d_point_x(sequence_example, prefix=prefix)
clear_bbox_3d_point_y(sequence_example, prefix=prefix)
clear_bbox_3d_point_z(sequence_example, prefix=prefix)
# pylint: enable=undefined-variable
msu.add_functions_to_module({
"get_" + name + "_at":
@ -419,6 +465,30 @@ def _create_region_with_prefix(name, prefix):
"clear_" + name:
functools.partial(clear_prefixed_bbox, prefix=prefix),
}, module_dict=globals())
msu.add_functions_to_module({
"get_" + name + "_point_at":
functools.partial(get_prefixed_point_at, prefix=prefix),
"add_" + name + "_point":
functools.partial(add_prefixed_point, prefix=prefix),
"get_" + name + "_point_size":
functools.partial(get_prefixed_point_size, prefix=prefix),
"has_" + name + "_point":
functools.partial(has_prefixed_point, prefix=prefix),
"clear_" + name + "_point":
functools.partial(clear_prefixed_point, prefix=prefix),
}, module_dict=globals())
msu.add_functions_to_module({
"get_" + name + "_3d_point_at":
functools.partial(get_prefixed_3d_point_at, prefix=prefix),
"add_" + name + "_3d_point":
functools.partial(add_prefixed_3d_point, prefix=prefix),
"get_" + name + "_3d_point_size":
functools.partial(get_prefixed_3d_point_size, prefix=prefix),
"has_" + name + "_3d_point":
functools.partial(has_prefixed_3d_point, prefix=prefix),
"clear_" + name + "_3d_point":
functools.partial(clear_prefixed_3d_point, prefix=prefix),
}, module_dict=globals())
PREDICTED_PREFIX = "PREDICTED"

View File

@ -436,6 +436,21 @@ TEST(MediaSequenceTest, RoundTripBBoxPointPrefixed) {
}
}
TEST(MediaSequenceTest, RoundTripBBox3dPoint) {
tensorflow::SequenceExample sequence;
std::vector<std::vector<std::tuple<float, float, float>>> points = {
{std::make_tuple(0.3, 0.5, 0.1), std::make_tuple(0.4, 0.7, 0.2)},
{std::make_tuple(0.7, 0.5, 0.3), std::make_tuple(0.3, 0.4, 0.4)}};
for (int i = 0; i < points.size(); ++i) {
AddBBox3dPoint(points[i], &sequence);
ASSERT_EQ(GetBBox3dPointSize(sequence), i + 1);
const auto& sequence_points = GetBBox3dPointAt(sequence, i);
for (int j = 0; j < sequence_points.size(); ++j) {
EXPECT_EQ(sequence_points[j], points[i][j]);
}
}
}
TEST(MediaSequenceTest, RoundTripRegionParts) {
tensorflow::SequenceExample sequence;
std::vector<std::string> parts = {"HEAD", "FEET"};

View File

@ -89,6 +89,9 @@ class MediaSequenceTest(tf.test.TestCase):
ms.add_bbox_xmax((0.47, 0.49), example)
ms.add_bbox_point_x((0.47, 0.49), example)
ms.add_bbox_point_y((0.47, 0.49), example)
ms.add_bbox_3d_point_x((0.47, 0.49), example)
ms.add_bbox_3d_point_y((0.47, 0.49), example)
ms.add_bbox_3d_point_z((0.47, 0.49), example)
ms.add_predicted_bbox_ymin((0.47, 0.49), example)
ms.add_predicted_bbox_xmin((0.47, 0.49), example)
ms.add_predicted_bbox_ymax((0.47, 0.49), example)
@ -133,6 +136,30 @@ class MediaSequenceTest(tf.test.TestCase):
ms.clear_bbox(example)
self.assertEqual(0, ms.get_bbox_size(example))
def test_point_round_trip(self):
example = tf.train.SequenceExample()
points = np.array([[0.1, 0.2],
[0.5, 0.6]])
ms.add_bbox_point(points, example)
ms.add_bbox_point(points, example)
self.assertEqual(2, ms.get_bbox_point_size(example))
self.assertAllClose(points, ms.get_bbox_point_at(0, example))
self.assertTrue(ms.has_bbox_point(example))
ms.clear_bbox_point(example)
self.assertEqual(0, ms.get_bbox_point_size(example))
def test_3d_point_round_trip(self):
example = tf.train.SequenceExample()
points = np.array([[0.1, 0.2, 0.3],
[0.5, 0.6, 0.7]])
ms.add_bbox_3d_point(points, example)
ms.add_bbox_3d_point(points, example)
self.assertEqual(2, ms.get_bbox_3d_point_size(example))
self.assertAllClose(points, ms.get_bbox_3d_point_at(0, example))
self.assertTrue(ms.has_bbox_3d_point(example))
ms.clear_bbox_3d_point(example)
self.assertEqual(0, ms.get_bbox_3d_point_size(example))
def test_predicted_bbox_round_trip(self):
example = tf.train.SequenceExample()
boxes = np.array([[0.1, 0.2, 0.3, 0.4],

View File

@ -19,6 +19,14 @@ package(default_visibility = [
"//mediapipe:__subpackages__",
])
cc_library(
name = "config",
hdrs = ["config.h"],
deps = [
"//mediapipe/framework:calculator_framework",
],
)
cc_library(
name = "cpu_op_resolver",
srcs = ["cpu_op_resolver.cc"],
@ -69,6 +77,7 @@ cc_test(
srcs = ["tensor_buffer_test.cc"],
deps = [
":tensor_buffer",
":config",
"//mediapipe/framework/port:gtest_main",
] + select({
"//mediapipe/gpu:disable_gpu": [],
@ -99,6 +108,7 @@ cc_library(
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
],
"//mediapipe:android": [
@ -108,7 +118,9 @@ cc_library(
"//mediapipe/framework/port:statusor",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
],
}) + ["@org_tensorflow//tensorflow/lite/core/api"],

View File

@ -0,0 +1,59 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#define MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#include "mediapipe/framework/calculator_framework.h"
// MediaPipe code should use the following defines to determine whether TFLite
// GPU support is available, and whether GL or Metal inference is available.
#ifdef MEDIAPIPE_DISABLE_GL_COMPUTE
#define MEDIAPIPE_TFLITE_GL_INFERENCE 0
#else
#define MEDIAPIPE_TFLITE_GL_INFERENCE 1
#endif // MEDIAPIPE_DISABLE_GL_COMPUTE
#ifdef MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 1
#else
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 0
#endif // MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_GPU_SUPPORTED \
((MEDIAPIPE_TFLITE_GL_INFERENCE) || (MEDIAPIPE_TFLITE_METAL_INFERENCE))
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <Metal/Metal.h>
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace mediapipe {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
typedef id<MTLBuffer> GpuTensor;
#else
struct DummyGpuTensor {};
typedef DummyGpuTensor GpuTensor; // Dummy define for less #ifdefs
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} // namespace mediapipe
#endif // MEDIAPIPE_UTIL_TFLITE_CONFIG_H_

View File

@ -130,11 +130,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto padding = params->padding;
auto compute_out_size = [padding](int image_size, int filter_size,
int stride) -> int {
return padding == kTfLitePaddingSame
? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
? (image_size - filter_size + stride) / stride
: 0;
return padding == kTfLitePaddingSame ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
? (image_size - filter_size + stride) / stride
: 0;
};
int out_width =

View File

@ -2,6 +2,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/tflite/config.h"
namespace mediapipe {
@ -12,7 +13,7 @@ TEST(Cpu, BasicTest) {
EXPECT_FALSE(tb.UsesGpu());
}
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if MEDIAPIPE_TFLITE_GL_INFERENCE
TEST(Gpu, BasicTest) {
TensorBuffer tb;
std::shared_ptr<tflite::gpu::gl::GlBuffer> tfg_tb =
@ -20,7 +21,7 @@ TEST(Gpu, BasicTest) {
tb = TensorBuffer(tfg_tb);
EXPECT_TRUE(tb.UsesGpu());
}
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // !MEDIAPIPE_TFLITE_GL_INFERENCE
} // namespace mediapipe

View File

@ -30,6 +30,13 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"
// This code should be enabled as soon as TensorFlow version, which mediapipe
// uses, will include this module.
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
namespace tflite {
namespace gpu {
namespace {
@ -51,6 +58,19 @@ ObjectDef GetSSBOObjectDef(int channels) {
mediapipe::Status TFLiteGPURunner::InitializeWithModel(
const tflite::FlatBufferModel& flatbuffer,
const tflite::OpResolver& op_resolver) {
// GraphFloat32 is created twice because, when OpenCL and OpenGL backends are
// initialized, different backend-specific graph transformations happen
// in-place. As GraphFloat32 is not copyable by design, we keep two copies of
// the graph until inference is built. This decision doesn't affect the amount
// of run time memory used, because both graph_gl_ and graph_cl_ are deleted
// in the end of the initialization stage.
graph_gl_ = std::make_unique<GraphFloat32>();
graph_cl_ = std::make_unique<GraphFloat32>();
MP_RETURN_IF_ERROR(
BuildFromFlatBuffer(flatbuffer, op_resolver, graph_gl_.get()));
MP_RETURN_IF_ERROR(
BuildFromFlatBuffer(flatbuffer, op_resolver, graph_cl_.get()));
for (const auto& input : graph_gl_->inputs()) {
input_shapes_.push_back(input->tensor.shape);
}
@ -140,6 +160,19 @@ mediapipe::Status TFLiteGPURunner::InitializeOpenGL(
absl::Status TFLiteGPURunner::InitializeOpenCL(
std::unique_ptr<InferenceBuilder>* builder) {
#ifdef __ANDROID__
cl::InferenceEnvironmentOptions env_options;
cl::InferenceEnvironmentProperties properties;
cl::InferenceOptions cl_options;
cl_options.priority1 = options_.priority1;
cl_options.priority2 = options_.priority2;
cl_options.priority3 = options_.priority3;
cl_options.usage = options_.usage;
MP_RETURN_IF_ERROR(
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
cl_options, std::move(*graph_cl_), builder));
#endif
return absl::OkStatus();
}

View File

@ -27,6 +27,10 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
namespace tflite {
namespace gpu {
@ -64,6 +68,9 @@ class TFLiteGPURunner {
mediapipe::Status Build();
mediapipe::Status Invoke();
std::vector<BHWC> GetInputShapes() { return input_shapes_; }
std::vector<BHWC> GetOutputShapes() { return output_shapes_; }
private:
mediapipe::Status InitializeOpenGL(
std::unique_ptr<InferenceBuilder>* builder);
@ -73,6 +80,10 @@ class TFLiteGPURunner {
InferenceOptions options_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#ifdef __ANDROID__
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#endif
// graph_ is maintained temporarily and becomes invalid after runner_ is ready
std::unique_ptr<GraphFloat32> graph_gl_;
std::unique_ptr<GraphFloat32> graph_cl_;

View File

@ -236,7 +236,7 @@ inline void CheckAndSetInvokerOptions() {
LOG(WARNING) << "Unsupported invoker mode selected on Android. "
<< "OpenMP linkage detected, so falling back to OpenMP";
flags_parallel_invoker_mode = PARALLEL_INVOKER_OPENMP;
#else // _OPENMP
#else // _OPENMP
// Fallback mode for active parallel invoker without OpenMP is ThreadPool.
LOG(WARNING) << "Unsupported invoker mode selected on Android. "
<< "Falling back to ThreadPool";
@ -273,7 +273,7 @@ inline void CheckAndSetInvokerOptions() {
#endif // _OPENMP
}
#else // PARALLEL_INVOKER_ACTIVE
#else // PARALLEL_INVOKER_ACTIVE
if (flags_parallel_invoker_mode != PARALLEL_INVOKER_NONE) {
LOG(ERROR) << "Parallel execution requested but PARALLEL_INVOKER_ACTIVE "
<< "compile flag is not set. Falling back to single threaded "

View File

@ -2082,8 +2082,8 @@ void RegionFlowComputation::WideBaselineMatchFeatures(
!defined(CV_WRAPPER_3X)
LOG(FATAL) << "Supported on only with OpenCV 3.0. "
<< "Use bazel build flag : --define CV_WRAPPER=3X";
#else // (defined(__ANDROID__) || defined(__APPLE__) ||
// defined(__EMSCRIPTEN__)) && !defined(CV_WRAPPER_3X)
#else // (defined(__ANDROID__) || defined(__APPLE__) ||
// defined(__EMSCRIPTEN__)) && !defined(CV_WRAPPER_3X)
results->clear();
const auto& frame1 = from_data_ptr->frame;

View File

@ -0,0 +1,50 @@
diff --git a/googletest/include/gtest/internal/gtest-internal.h b/googletest/include/gtest/internal/gtest-internal.h
index 7f1a5b00e..c36029ee1 100644
--- a/googletest/include/gtest/internal/gtest-internal.h
+++ b/googletest/include/gtest/internal/gtest-internal.h
@@ -94,6 +94,12 @@ namespace proto2 {
class MessageLite;
}
+namespace google {
+namespace protobuf {
+class MessageLite;
+}
+}
+
namespace testing {
// Forward declarations.
@@ -881,10 +887,15 @@ class GTEST_API_ Random {
typename std::remove_const<typename std::remove_reference<T>::type>::type
// IsAProtocolMessage<T>::value is a compile-time bool constant that's
-// true if and only if T is type proto2::MessageLite or a subclass of it.
+// true if and only if T is type proto2::MessageLite or
+// google::protobuf::MessageLite or a subclass of one of them.
template <typename T>
struct IsAProtocolMessage
- : public std::is_convertible<const T*, const ::proto2::MessageLite*> {};
+ : public std::integral_constant<
+ bool,
+ std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
+ std::is_convertible<
+ const T*, const ::google::protobuf::MessageLite*>::value> {};
// When the compiler sees expression IsContainerTest<C>(0), if C is an
// STL-style container class, the first overload of IsContainerTest
diff --git a/googletest/test/gtest_unittest.cc b/googletest/test/gtest_unittest.cc
index 005a2d40d..631180e3d 100644
--- a/googletest/test/gtest_unittest.cc
+++ b/googletest/test/gtest_unittest.cc
@@ -7115,6 +7115,10 @@ TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAProtocolMessage) {
EXPECT_TRUE(IsAProtocolMessage<::proto2::MessageLite>::value);
}
+TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAnOpenSourceProtocolMessage) {
+ EXPECT_TRUE(IsAProtocolMessage<::google::protobuf::MessageLite>::value);
+}
+
// Tests that IsAProtocolMessage<T>::value is false when T is neither
// ::proto2::Message nor a sub-class of it.
TEST(IsAProtocolMessageTest, ValueIsFalseWhenTypeIsNotAProtocolMessage) {