Project import generated by Copybara.

GitOrigin-RevId: e3a43e4e5e519cd14df7095749059e2613bdcf76
MediaPipe Team 2020-07-08 17:34:05 -07:00 committed by jqtang
parent 67bd8a2bf0
commit e9fbe868e5
96 changed files with 2547 additions and 1176 deletions

BUILD

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors. # Copyright 2019 The MediaPipe Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.


@ -101,7 +101,7 @@ run code search using
## Videos ## Videos
* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw) * [YouTube Channel](https://www.youtube.com/c/MediaPipe)
## Events ## Events
@ -123,7 +123,7 @@ run code search using
* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome * [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
MediaPipe related frameworks, libraries and software MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users * [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
community discussion around MediaPipe community discussion around MediaPipe


@ -37,10 +37,19 @@ http_archive(
) )
# GoogleTest/GoogleMock framework. Used by most unit-tests. # GoogleTest/GoogleMock framework. Used by most unit-tests.
# Last updated 2020-06-30.
http_archive( http_archive(
name = "com_google_googletest", name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/master.zip"], urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
strip_prefix = "googletest-master", patches = [
# fix for https://github.com/google/googletest/issues/2817
"@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
],
patch_args = [
"-p1",
],
strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
) )
# Google Benchmark library. # Google Benchmark library.

build_ios_examples.sh (new file)

@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
#
# Script to build all MediaPipe iOS example apps.
#
# To build all apps and store them in out_dir:
# $ ./build_ios_examples.sh -d out_dir
# Omitting -d and the associated directory saves all generated IPAs in the
# current directory.
# $ ./build_ios_examples.sh -d out_dir --nostrip
# Same as above except that the symbols are not stripped.
set -e
out_dir="."
strip=true
app_dir="mediapipe/examples/ios"
bin_dir="bazel-bin"
declare -a default_bazel_flags=(build -c opt --config=ios_arm64)
while [[ -n $1 ]]; do
case $1 in
-d)
shift
out_dir=$1
;;
--nostrip)
strip=false
;;
*)
echo "Unsupported input argument $1."
exit 1
;;
esac
shift
done
echo "app_dir: $app_dir"
echo "out_dir: $out_dir"
echo "strip: $strip"
declare -a bazel_flags
apps="${app_dir}/*"
for app in ${apps}; do
if [[ -d "${app}" ]]; then
target_name=${app##*/}
target="${app}:${target_name}"
echo "=== Target: ${target}"
bazel_flags=("${default_bazel_flags[@]}")
bazel_flags+=(${target})
if [[ $strip == true ]]; then
bazel_flags+=(--linkopt=-s)
fi
bazel "${bazel_flags[@]}"
cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}"
fi
done


@ -149,15 +149,15 @@ When possible, these calculators use platform-specific functionality to share da
The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU. The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU.
| ![How GPU calculators interact](../images/gpu_example_graph.png) | ![How GPU calculators interact](../images/gpu_example_graph.png)
| :--------------------------------------------------------------------------: |
| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. | Video frames from the camera are fed into the graph as `GpuBuffer` packets. The
: The input stream is accessed by two calculators in parallel. : input stream is accessed by two calculators in parallel.
: `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`, : `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,
: which is then sent through a grayscale converter and a canny filter (both : which is then sent through a grayscale converter and a canny filter (both based
: based on OpenCV and running on the CPU), whose output is then converted into : on OpenCV and running on the CPU), whose output is then converted into a
: a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, : `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as
: takes as input both the original `GpuBuffer` and the one coming out of the : input both the original `GpuBuffer` and the one coming out of the edge detector,
: edge detector, and overlays them using a shader. The output is then sent : and overlays them using a shader. The output is then sent back to the
: back to the application using a callback calculator, and the application : application using a callback calculator, and the application renders the image
: renders the image to the screen using OpenGL.* : to the screen using OpenGL.
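As a rough illustration of the pipeline described in this caption, the C++ sketch below builds a `CalculatorGraphConfig` with the same shape. Only `GpuBufferToImageFrameCalculator` and `GlOverlayCalculator` are named in the text above; the remaining calculator names, stream names, and the CPU-to-GPU back-conversion node are illustrative assumptions.

```c++
// A minimal sketch, assuming illustrative calculator and stream names, of the
// mixed CPU/GPU edge-overlay graph described above.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig MakeEdgeOverlayGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "input_video"    # GpuBuffer packets from the camera
    output_stream: "output_video"  # GpuBuffer packets rendered by the app

    # GPU -> CPU: unpack the GpuBuffer into an ImageFrame.
    node {
      calculator: "GpuBufferToImageFrameCalculator"
      input_stream: "input_video"
      output_stream: "video_cpu"
    }
    # CPU path: grayscale conversion and Canny edge detection (OpenCV).
    node {
      calculator: "GrayscaleCalculator"  # assumed name
      input_stream: "video_cpu"
      output_stream: "gray_cpu"
    }
    node {
      calculator: "CannyEdgesCalculator"  # assumed name
      input_stream: "gray_cpu"
      output_stream: "edges_cpu"
    }
    # CPU -> GPU: repack the edge image into a GpuBuffer.
    node {
      calculator: "ImageFrameToGpuBufferCalculator"  # assumed name
      input_stream: "edges_cpu"
      output_stream: "edges_gpu"
    }
    # Multi-input GPU overlay of the edges on the original video.
    node {
      calculator: "GlOverlayCalculator"
      input_stream: "input_video"
      input_stream: "edges_gpu"
      output_stream: "output_video"
    }
  )pb");
}
```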


@ -184,12 +184,8 @@ app:
### Prerequisite ### Prerequisite
1. Install [Xcode](https://developer.apple.com/xcode/) and the Command Line 1. Install [Xcode](https://developer.apple.com/xcode/), and additionally
Tools. install the Command Line Tools by:
Follow Apple's instructions to obtain the required development certificates
and provisioning profiles for your iOS device. Install the Command Line
Tools by
```bash ```bash
xcode-select --install xcode-select --install
@ -209,26 +205,31 @@ app:
pip3 install --user six pip3 install --user six
``` ```
4. Clone the MediaPipe repository. 4. Follow
[Apple's instructions](https://developer.apple.com/support/certificates/) to
obtain the required development certificates and provisioning profiles for
your iOS device.
Tip: You can use the following command to see the provisioning profiles you have
previously downloaded using Xcode: `open
~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
and download a profile on
[Apple's developer site](https://developer.apple.com/account/resources/).
5. Clone the MediaPipe repository.
```bash ```bash
git clone https://github.com/google/mediapipe.git git clone https://github.com/google/mediapipe.git
``` ```
5. Symlink or copy your provisioning profile to 6. In the cloned MediaPipe repository, symlink or copy your provisioning profile
`mediapipe/mediapipe/provisioning_profile.mobileprovision`. to `mediapipe/provisioning_profile.mobileprovision`, e.g.,
```bash ```bash
cd mediapipe cd mediapipe
ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
``` ```
Tip: You can use this command to see the provisioning profiles you have
previously downloaded using Xcode: `open
~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
and download a profile on
[Apple's developer site](https://developer.apple.com/account/resources/).
### Option 1: Build with Bazel in Command Line ### Option 1: Build with Bazel in Command Line
1. Modify the `bundle_id` field of the app's `ios_application` build target to 1. Modify the `bundle_id` field of the app's `ios_application` build target to
@ -246,6 +247,10 @@ app:
You may see a permission request from `codesign` in order to sign the app. You may see a permission request from `codesign` in order to sign the app.
Tip: You can run this
[script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
to build all MediaPipe iOS example apps.
3. In Xcode, open the `Devices and Simulators` window (command-shift-2). 3. In Xcode, open the `Devices and Simulators` window (command-shift-2).
4. Make sure your device is connected. You will see a list of installed apps. 4. Make sure your device is connected. You will see a list of installed apps.


@ -44,6 +44,18 @@ apps, see these [instructions](./building_examples.md#ios).
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
to install Bazel 2.0 or higher. to install Bazel 2.0 or higher.
For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to
be built from source.
```bash
# For Bazel 3.0.0
wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
unzip bazel-3.0.0-dist.zip
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
sudo cp output/bazel /usr/local/bin/
```
3. Install OpenCV and FFmpeg. 3. Install OpenCV and FFmpeg.
Option 1. Use package manager tool to install the pre-compiled OpenCV Option 1. Use package manager tool to install the pre-compiled OpenCV
@ -58,6 +70,14 @@ apps, see these [instructions](./building_examples.md#ios).
libopencv-imgproc-dev libopencv-video-dev libopencv-imgproc-dev libopencv-video-dev
``` ```
[`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia
Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be
modified.
```bash
sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
```
Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source
and modify MediaPipe's OpenCV config. and modify MediaPipe's OpenCV config.
@ -493,14 +513,14 @@ cameras. Alternatively, you use a video file as input.
```bash ```bash
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \ username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
https://storage.googleapis.com/bazel/2.0.0/release/bazel-2.0.0-installer-linux-x86_64.sh && \ https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \
sudo mkdir -p /usr/local/bazel/2.0.0 && \ sudo mkdir -p /usr/local/bazel/3.0.0 && \
chmod 755 bazel-2.0.0-installer-linux-x86_64.sh && \ chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \
sudo ./bazel-2.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/2.0.0 && \ sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \
source /usr/local/bazel/2.0.0/lib/bazel/bin/bazel-complete.bash source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \ username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel' alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel'
``` ```
6. Checkout MediaPipe repository. 6. Checkout MediaPipe repository.


@ -101,7 +101,7 @@ run code search using
## Videos ## Videos
* [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw) * [YouTube Channel](https://www.youtube.com/c/MediaPipe)
## Events ## Events
@ -123,7 +123,7 @@ run code search using
* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome * [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
MediaPipe related frameworks, libraries and software MediaPipe related frameworks, libraries and software
* [Slack community](https://mediapipe.slack.com) for MediaPipe users * [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
community discussion around MediaPipe community discussion around MediaPipe


@ -1,6 +1,6 @@
--- ---
layout: default layout: default
title: Hand title: Hands
parent: Solutions parent: Solutions
nav_order: 3 nav_order: 3
--- ---
@ -219,9 +219,13 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web).
## Resources ## Resources
* Google AI Blog: [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html) * Google AI Blog:
* TensorFlow Blog: [Face and hand tracking in the browser with MediaPipe and [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html) * TensorFlow Blog:
[Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
* Paper:
[MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214)
([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk))
* Palm detection model: * Palm detection model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite), [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1) [TF.js model](https://tfhub.dev/mediapipe/handdetector/1)


@ -188,5 +188,8 @@ to visualize its associated subgraphs, please see
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
Shape Supervision](https://arxiv.org/abs/2003.03522) Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper:
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite) * [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite)
* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite) * [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite)


@ -21,7 +21,16 @@ available on Linux, Android, or iOS.
## Enabling tracing and profiling ## Enabling tracing and profiling
To enable tracing/profiling of a mediapipe graph, the `CalculatorGraphConfig` (in To enable tracing and profiling of a mediapipe graph:
1. The profiling library must be linked to the framework.
2. Tracing and profiling must be enabled in the graph configuration.
The profiling library is linked to the framework by default. If needed,
the profiling library can be omitted from the framework using the bazel
command line option: `--define MEDIAPIPE_PROFILING=0`.
To enable tracing and profiling, the `CalculatorGraphConfig` (in
[calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto)) [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
representing the graph must have a `profiler_config` message at its root. Here representing the graph must have a `profiler_config` message at its root. Here
is a simple setup that turns on a few extra options: is a simple setup that turns on a few extra options:
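A minimal sketch of such a setup, assuming the commonly documented profiler fields (`trace_enabled`, `enable_profiler`, `trace_log_interval_count`), is shown below; the values and the placeholder node are illustrative only.

```c++
// A minimal sketch of a graph config with tracing and profiling enabled.
// Field values are assumptions for illustration; consult the profiler options
// referenced from calculator.proto for exact semantics.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig MakeProfiledGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    profiler_config {
      trace_enabled: true            # record timing events for the tracer
      enable_profiler: true          # aggregate per-calculator statistics
      trace_log_interval_count: 200  # trace log output cadence (assumed value)
    }
    input_stream: "input_video"
    output_stream: "output_video"
    node {
      calculator: "PassThroughCalculator"  # placeholder node for illustration
      input_stream: "input_video"
      output_stream: "output_video"
    }
  )pb");
}
```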


@ -386,14 +386,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const int input_width = input_mat.cols; const int input_width = input_mat.cols;
const int input_height = input_mat.rows; const int input_height = input_mat.rows;
if (!output_height_ || !output_width_) { int output_width;
output_height_ = input_height; int output_height;
output_width_ = input_width; ComputeOutputDimensions(input_width, input_height, &output_width,
} &output_height);
if (output_width_ > 0 && output_height_ > 0) {
cv::Mat scaled_mat; cv::Mat scaled_mat;
int output_width = output_width_;
int output_height = output_height_;
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) { if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
int scale_flag = int scale_flag =
input_mat.cols > output_width_ && input_mat.rows > output_height_ input_mat.cols > output_width_ && input_mat.rows > output_height_
@ -416,7 +415,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const int bottom = output_height_ - target_height - top; const int bottom = output_height_ - target_height - top;
const int left = (output_width_ - target_width) / 2; const int left = (output_width_ - target_width) / 2;
const int right = output_width_ - target_width - left; const int right = output_width_ - target_width - left;
cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left, right, cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left,
right,
options_.constant_padding() ? cv::BORDER_CONSTANT options_.constant_padding() ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE); : cv::BORDER_REPLICATE);
} else { } else {
@ -426,6 +426,8 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
output_height = target_height; output_height = target_height;
} }
} }
input_mat = scaled_mat;
}
if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
auto padding = absl::make_unique<std::array<float, 4>>(); auto padding = absl::make_unique<std::array<float, 4>>();
@ -437,10 +439,33 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
} }
cv::Mat rotated_mat; cv::Mat rotated_mat;
cv::Size rotated_size(output_width, output_height);
if (input_mat.size() == rotated_size) {
const int angle = RotationModeToDegrees(rotation_); const int angle = RotationModeToDegrees(rotation_);
cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0); cv::Point2f src_center(input_mat.cols / 2.0, input_mat.rows / 2.0);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0); cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size()); cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
} else {
switch (rotation_) {
case mediapipe::RotationMode_Mode_UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0:
LOG(ERROR) << "Not rotating image.";
rotated_mat = input_mat;
break;
case mediapipe::RotationMode_Mode_ROTATION_90:
LOG(ERROR) << "Rotating image by 90 degrees ccw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
break;
case mediapipe::RotationMode_Mode_ROTATION_180:
LOG(ERROR) << "Rotating image by 180 degrees.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
break;
case mediapipe::RotationMode_Mode_ROTATION_270:
LOG(ERROR) << "Rotating image by 90 degrees cw.";
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
break;
}
}
cv::Mat flipped_mat; cv::Mat flipped_mat;
if (flip_horizontally_ || flip_vertically_) { if (flip_horizontally_ || flip_vertically_) {


@ -139,7 +139,6 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
static_cast<::mediapipe::StatusCode>(status.code()), static_cast<::mediapipe::StatusCode>(status.code()),
status.ToString()); status.ToString());
} }
auto session = absl::make_unique<TensorFlowSession>(); auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session); session->session = std::move(saved_model->session);


@ -14,6 +14,7 @@
# #
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
@ -202,6 +203,13 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
selects.config_setting_group(
name = "gpu_inference_disabled",
match_any = [
"//mediapipe/gpu:disable_gpu",
],
)
cc_library( cc_library(
name = "tflite_inference_calculator", name = "tflite_inference_calculator",
srcs = ["tflite_inference_calculator.cc"], srcs = ["tflite_inference_calculator.cc"],
@ -226,13 +234,14 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"//mediapipe/util/tflite:config",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
] + select({ ] + selects.with_or({
"//mediapipe/gpu:disable_gpu": [], ":gpu_inference_disabled": [],
"//mediapipe:ios": [ "//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:MPPMetalUtil",
@ -285,6 +294,7 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/util/tflite:config",
":util", ":util",
":tflite_converter_calculator_cc_proto", ":tflite_converter_calculator_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
@ -295,23 +305,26 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
] + select({ ] + selects.with_or({
"//mediapipe/gpu:disable_gpu": [], ":gpu_inference_disabled": [],
"//mediapipe:ios": [ "//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
], ],
"//conditions:default": [ "//conditions:default": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
], ],
}) + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer",
],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -348,8 +361,8 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
] + select({ ] + selects.with_or({
"//mediapipe/gpu:disable_gpu": [], ":gpu_inference_disabled": [],
"//mediapipe:ios": [], "//mediapipe:ios": [],
"//conditions:default": [ "//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
@ -404,6 +417,7 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/util/tflite:config",
":util", ":util",
":tflite_tensors_to_detections_calculator_cc_proto", ":tflite_tensors_to_detections_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
@ -415,8 +429,8 @@ cc_library(
"//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
] + select({ ] + selects.with_or({
"//mediapipe/gpu:disable_gpu": [], ":gpu_inference_disabled": [],
"//mediapipe:ios": [ "//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
@ -492,6 +506,8 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# To run this with native GPU on Linux, use:
# bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local
cc_test( cc_test(
name = "tflite_inference_calculator_test", name = "tflite_inference_calculator_test",
srcs = ["tflite_inference_calculator_test.cc"], srcs = ["tflite_inference_calculator_test.cc"],


@ -22,19 +22,23 @@
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #ifndef MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#endif // MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // !MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS) #if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h> #import <Metal/Metal.h>
#import <MetalKit/MetalKit.h> #import <MetalKit/MetalKit.h>
@ -43,13 +47,7 @@
#include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS #endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(MEDIAPIPE_IOS)
typedef id<MTLBuffer> GpuTensor;
#endif
namespace { namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader. constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
@ -73,7 +71,7 @@ constexpr char kMatrixTag[] = "MATRIX";
namespace mediapipe { namespace mediapipe {
namespace { namespace {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader; using ::tflite::gpu::gl::GlShader;
@ -83,13 +81,13 @@ struct GPUData {
GlShader shader; GlShader shader;
GlProgram program; GlProgram program;
}; };
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
GpuTensor buffer; GpuTensor buffer;
id<MTLComputePipelineState> pipeline_state; id<MTLComputePipelineState> pipeline_state;
}; };
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} // namespace } // namespace
@ -157,13 +155,13 @@ class TfLiteConverterCalculator : public CalculatorBase {
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr; std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_; std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MPPMetalHelper* gpu_helper_ = nullptr; MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_out_; std::unique_ptr<GPUData> gpu_data_out_;
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
bool initialized_ = false; bool initialized_ = false;
bool use_gpu_ = false; bool use_gpu_ = false;
@ -178,6 +176,18 @@ class TfLiteConverterCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(TfLiteConverterCalculator); REGISTER_CALCULATOR(TfLiteConverterCalculator);
namespace {
template <class CC>
bool ShouldUseGpu(CC* cc) {
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
return cc->Inputs().HasTag(kGpuBufferTag) ||
cc->Outputs().HasTag(kTensorsGpuTag);
#else
return false;
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
} // namespace
::mediapipe::Status TfLiteConverterCalculator::GetContract( ::mediapipe::Status TfLiteConverterCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
// Confirm only one of the input streams is present. // Confirm only one of the input streams is present.
@ -189,37 +199,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^ RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
cc->Outputs().HasTag(kTensorsGpuTag)); cc->Outputs().HasTag(kTensorsGpuTag));
bool use_gpu = false;
if (cc->Inputs().HasTag(kImageFrameTag)) { if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>(); cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
} }
if (cc->Inputs().HasTag(kMatrixTag)) { if (cc->Inputs().HasTag(kMatrixTag)) {
cc->Inputs().Tag(kMatrixTag).Set<Matrix>(); cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
} }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) #ifndef MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kGpuBufferTag)) { if (cc->Inputs().HasTag(kGpuBufferTag)) {
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // !MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kTensorsTag)) { if (cc->Outputs().HasTag(kTensorsTag)) {
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
} }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag(kTensorsGpuTag)) { if (cc->Outputs().HasTag(kTensorsGpuTag)) {
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>(); cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
} }
#endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) { if (ShouldUseGpu(cc)) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} }
// Assign this calculator's default InputStreamHandler. // Assign this calculator's default InputStreamHandler.
@ -233,14 +237,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag(kGpuBufferTag) || use_gpu_ = ShouldUseGpu(cc);
cc->Outputs().HasTag(kGpuBufferTag)) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
}
if (use_gpu_) { if (use_gpu_) {
// Cannot mix CPU/GPU streams. // Cannot mix CPU/GPU streams.
@ -248,12 +245,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs().HasTag(kTensorsGpuTag)); cc->Outputs().HasTag(kTensorsGpuTag));
// Cannot use quantization. // Cannot use quantization.
use_quantized_tensors_ = false; use_quantized_tensors_ = false;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else { } else {
interpreter_ = absl::make_unique<tflite::Interpreter>(); interpreter_ = absl::make_unique<tflite::Interpreter>();
interpreter_->AddTensors(1); interpreter_->AddTensors(1);
@ -282,12 +279,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
} }
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) interpreter_.reset();
#if MEDIAPIPE_TFLITE_GL_INFERENCE
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
#if defined(MEDIAPIPE_IOS)
gpu_data_out_.reset(); gpu_data_out_.reset();
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -318,8 +315,14 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(format != mediapipe::ImageFormat::VEC32F1) RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
<< "Only 8-bit input images are supported for quantization."; << "Only 8-bit input images are supported for quantization.";
quant.type = kTfLiteAffineQuantization; quant.type = kTfLiteAffineQuantization;
quant.params = nullptr; auto quant_params = static_cast<TfLiteAffineQuantization*>(
// Optional: Set 'quant' quantization params here if needed. malloc(sizeof(TfLiteAffineQuantization)));
quant_params->scale = TfLiteFloatArrayCreate(1);
quant_params->scale->data[0] = 1.0;
quant_params->zero_point = TfLiteIntArrayCreate(1);
quant_params->zero_point->data[0] = 0;
quant_params->quantized_dimension = 0;
quant.params = quant_params;
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "", interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
{channels_preserved}, quant); {channels_preserved}, quant);
} else { } else {
@ -414,7 +417,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( ::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
// GpuBuffer to tflite::gpu::GlBuffer conversion. // GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input = const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -451,7 +454,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs() cc->Outputs()
.Tag(kTensorsGpuTag) .Tag(kTensorsGpuTag)
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
// GpuBuffer to id<MTLBuffer> conversion. // GpuBuffer to id<MTLBuffer> conversion.
const auto& input = const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -490,13 +493,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
#else #else
RET_CHECK_FAIL() << "GPU processing is not enabled."; RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) #if MEDIAPIPE_TFLITE_GPU_SUPPORTED
// Get input image sizes. // Get input image sizes.
const auto& input = const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@ -512,9 +515,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK_FAIL() << "Unsupported GPU input format."; RET_CHECK_FAIL() << "Unsupported GPU input format.";
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Num input channels is less than desired output."; RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif // !MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status { [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
// Device memory. // Device memory.
@ -559,7 +562,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
RET_CHECK(include_alpha) RET_CHECK(include_alpha)
<< "iOS GPU inference currently accepts only RGBA input."; << "iOS GPU inference currently accepts only RGBA input.";
@ -616,7 +619,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(gpu_data_out_->pipeline_state != nil) RET_CHECK(gpu_data_out_->pipeline_state != nil)
<< "Couldn't create pipeline state " << "Couldn't create pipeline state "
<< [[error localizedDescription] UTF8String]; << [[error localizedDescription] UTF8String];
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }


@ -22,6 +22,7 @@
#include "mediapipe/calculators/tflite/util.h" #include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h"
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
#include "mediapipe/util/cpu_util.h" #include "mediapipe/util/cpu_util.h"
@ -33,7 +34,7 @@
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h" #include "mediapipe/util/tflite/tflite_gpu_runner.h"
@ -42,9 +43,9 @@
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // !MEDIAPIPE_DISABLE_GL_COMPUTE #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_IOS) #if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h> #import <Metal/Metal.h>
#import <MetalKit/MetalKit.h> #import <MetalKit/MetalKit.h>
@ -56,7 +57,7 @@
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" #include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h" #include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
#endif // iOS #endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
#if !defined(MEDIAPIPE_EDGE_TPU) #if !defined(MEDIAPIPE_EDGE_TPU)
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
@ -71,12 +72,6 @@ int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size; return (size + group_size - 1) / group_size;
} }
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(MEDIAPIPE_IOS)
typedef id<MTLBuffer> GpuTensor;
#endif
// Round up n to next multiple of m. // Round up n to next multiple of m.
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
@ -112,13 +107,13 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
// * Aux // * Aux
namespace mediapipe { namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
using ::tflite::gpu::gl::CopyBuffer; using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlBuffer;
#endif #endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) #if MEDIAPIPE_TFLITE_GPU_SUPPORTED
namespace { namespace {
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
@ -126,7 +121,7 @@ struct GPUData {
::tflite::gpu::BHWC shape; ::tflite::gpu::BHWC shape;
}; };
} // namespace } // namespace
#endif #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
// Returns number of threads to configure XNNPACK delegate with. // Returns number of threads to configure XNNPACK delegate with.
// (Equal to user provided value if specified. Otherwise, it returns number of // (Equal to user provided value if specified. Otherwise, it returns number of
@ -152,7 +147,7 @@ int GetXnnpackNumThreads(
// Creates an interpreter with given model and calls invoke(). // Creates an interpreter with given model and calls invoke().
// Optionally run inference on CPU/GPU. // Optionally run inference on CPU/GPU.
// //
// This calculator is designed to be used with the TfLiteConverterCalcualtor, // This calculator is designed to be used with the TfLiteConverterCalculator,
// to get the appropriate inputs. // to get the appropriate inputs.
// //
// When the input tensors are on CPU, gpu inference is optional and can be // When the input tensors are on CPU, gpu inference is optional and can be
@ -183,7 +178,6 @@ int GetXnnpackNumThreads(
// options: { // options: {
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] { // [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
// model_path: "modelname.tflite" // model_path: "modelname.tflite"
// delegate { gpu {} }
// } // }
// } // }
// } // }
@ -192,11 +186,12 @@ int GetXnnpackNumThreads(
// //
// node { // node {
// calculator: "TfLiteInferenceCalculator" // calculator: "TfLiteInferenceCalculator"
// input_stream: "TENSORS:tensor_image" // input_stream: "TENSORS_GPU:tensor_image"
// input_side_packet: "MODEL:model" // input_side_packet: "MODEL:model"
// output_stream: "TENSORS:tensors" // output_stream: "TENSORS_GPU:tensors"
// options: { // options: {
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] { // [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
// model_path: "modelname.tflite"
// delegate { gpu {} } // delegate { gpu {} }
// } // }
// } // }
@ -228,24 +223,45 @@ class TfLiteInferenceCalculator : public CalculatorBase {
::mediapipe::Status LoadModel(CalculatorContext* cc); ::mediapipe::Status LoadModel(CalculatorContext* cc);
::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc); ::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
::mediapipe::Status LoadDelegate(CalculatorContext* cc); ::mediapipe::Status LoadDelegate(CalculatorContext* cc);
::mediapipe::Status InitTFLiteGPURunner(); ::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
::mediapipe::Status ProcessInputsCpu(
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
::mediapipe::Status ProcessOutputsCpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
::mediapipe::Status ProcessInputsGpu(
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
::mediapipe::Status ProcessOutputsGpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
::mediapipe::Status RunInContextIfNeeded(
std::function<::mediapipe::Status(void)> f) {
if (gpu_inference_) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
return gpu_helper_.RunInGlContext(std::move(f));
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
return f();
}
Packet model_packet_; Packet model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<tflite::Interpreter> interpreter_;
TfLiteDelegatePtr delegate_; TfLiteDelegatePtr delegate_;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::vector<std::unique_ptr<GPUData>> gpu_data_in_; std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_; std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_; std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MPPMetalHelper* gpu_helper_ = nullptr; MPPMetalHelper* gpu_helper_ = nullptr;
std::vector<std::unique_ptr<GPUData>> gpu_data_in_; std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_; std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
id<MTLComputePipelineState> fp32_to_fp16_program_; id<MTLComputePipelineState> fp32_to_fp16_program_;
TFLBufferConvert* converter_from_BPHWC4_ = nil; TFLBufferConvert* converter_from_BPHWC4_ = nil;
#endif #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_EDGE_TPU) #if defined(MEDIAPIPE_EDGE_TPU)
std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ = std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ =
@ -263,6 +279,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Calculator Core Section // Calculator Core Section
namespace {
template <class CC>
bool ShouldUseGpu(CC* cc) {
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
const auto& options =
cc->template Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
return options.use_gpu() ||
(options.has_delegate() && options.delegate().has_gpu()) ||
cc->Inputs().HasTag(kTensorsGpuTag) ||
cc->Outputs().HasTag(kTensorsGpuTag);
#else
return false;
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
} // namespace
::mediapipe::Status TfLiteInferenceCalculator::GetContract( ::mediapipe::Status TfLiteInferenceCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^ RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^
@ -276,32 +308,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->InputSidePackets().HasTag("MODEL")) cc->InputSidePackets().HasTag("MODEL"))
<< "Either model as side packet or model path in options is required."; << "Either model as side packet or model path in options is required.";
bool use_gpu =
options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
if (cc->Inputs().HasTag(kTensorsTag)) if (cc->Inputs().HasTag(kTensorsTag))
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
<< "GPU input is compatible with GPU delegate only.";
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kTensorsTag)) if (cc->Outputs().HasTag(kTensorsTag))
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
<< "GPU output is compatible with GPU delegate only.";
if (cc->Inputs().HasTag(kTensorsGpuTag))
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
if (cc->Outputs().HasTag(kTensorsGpuTag))
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>(); cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
cc->InputSidePackets() cc->InputSidePackets()
@ -312,10 +327,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>(); cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
} }
if (use_gpu) { if (ShouldUseGpu(cc)) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif #endif
} }
@ -331,149 +346,111 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
const auto& options = const auto& options =
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>(); cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
gpu_inference_ = options.use_gpu();
if (cc->Inputs().HasTag(kTensorsGpuTag)) { gpu_inference_ = ShouldUseGpu(cc);
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) gpu_input_ = cc->Inputs().HasTag(kTensorsGpuTag);
gpu_input_ = true; gpu_output_ = cc->Outputs().HasTag(kTensorsGpuTag);
gpu_inference_ = true; // Inference must be on GPU also.
#else
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
<< "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
}
if (cc->Outputs().HasTag(kTensorsGpuTag)) { use_advanced_gpu_api_ = MEDIAPIPE_TFLITE_GL_INFERENCE &&
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) options.has_delegate() &&
gpu_output_ = true; options.delegate().has_gpu() &&
RET_CHECK(cc->Inputs().HasTag(kTensorsGpuTag)) options.delegate().gpu().use_advanced_gpu_api();
<< "GPU output must also have GPU Input."; if (use_advanced_gpu_api_ && !gpu_input_) {
#else LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag)) "Falling back to the default TFLite API.";
<< "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
}
use_advanced_gpu_api_ = false;
if (use_advanced_gpu_api_ && !(gpu_input_ && gpu_output_)) {
LOG(WARNING)
<< "Cannot use advanced GPU APIs, both inputs and outputs must "
"be GPU buffers. Falling back to the default TFLite API.";
use_advanced_gpu_api_ = false; use_advanced_gpu_api_ = false;
} }
CHECK(!use_advanced_gpu_api_ || gpu_inference_);
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
if (gpu_inference_) { if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(MEDIAPIPE_IOS)
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner() return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: LoadDelegate(cc); : LoadDelegate(cc);
})); }));
if (use_advanced_gpu_api_) return ::mediapipe::OkStatus(); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
#else gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
MP_RETURN_IF_ERROR(LoadDelegate(cc)); MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif #endif
} else { } else {
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) // TODO: why only on these platforms?
// It seems that the XNNPACK delegate fails to load on Linux.
#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \
defined(MEDIAPIPE_IOS)
MP_RETURN_IF_ERROR(LoadDelegate(cc)); MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif // __EMSCRIPTEN__ || ANDROID #endif // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
// 0. Declare outputs // 0. Declare outputs
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || defined(MEDIAPIPE_IOS)
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
#endif
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>(); auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
// 1. Receive pre-processed tensor inputs. // 1. Receive pre-processed tensor inputs.
if (use_advanced_gpu_api_ && gpu_output_) { if (gpu_input_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) { } else {
return ::mediapipe::OkStatus(); MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
} }
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>(); // 2. Run inference.
RET_CHECK(!input_tensors.empty()); #if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( if (gpu_inference_ && use_advanced_gpu_api_) {
[this, &input_tensors, &output_tensors_gpu]() -> ::mediapipe::Status { RET_CHECK(tflite_gpu_runner_->Invoke().ok());
for (int i = 0; i < input_tensors.size(); ++i) { } else {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
input_tensors[i].id(), i));
} }
// Allocate output tensor. #else
output_tensors_gpu->resize(gpu_data_out_.size()); RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
for (int i = 0; i < gpu_data_out_.size(); ++i) { #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
GpuTensor& tensor = output_tensors_gpu->at(i);
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>( // 3. Output processed tensors.
gpu_data_out_[i]->elements, &tensor)); if (gpu_output_ || use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i)); std::move(output_tensors_gpu)));
} } else {
return ::mediapipe::OkStatus(); MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
}));
#endif
} else if (gpu_input_) {
// Read GPU input into SSBO.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
// Explicit copy input.
gpu_data_in_.resize(input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
RET_CHECK_CALL(
CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); });
#elif defined(MEDIAPIPE_IOS)
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
} }
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>(); ::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
RET_CHECK_GT(input_tensors.size(), 0); return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
// Explicit copy input with conversion float 32 bits to 16 bits. if (delegate_) {
gpu_data_in_.resize(input_tensors.size()); interpreter_ = nullptr;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer]; delegate_ = nullptr;
command_buffer.label = @"TfLiteInferenceCalculatorConvert"; #if MEDIAPIPE_TFLITE_GPU_SUPPORTED
id<MTLComputeCommandEncoder> compute_encoder = if (gpu_inference_) {
[command_buffer computeCommandEncoder]; for (int i = 0; i < gpu_data_in_.size(); ++i) {
[compute_encoder setComputePipelineState:fp32_to_fp16_program_]; gpu_data_in_[i].reset();
for (int i = 0; i < input_tensors.size(); ++i) {
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
const int threadgroups =
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
threadsPerThreadgroup:threads_per_group];
} }
[compute_encoder endEncoding]; for (int i = 0; i < gpu_data_out_.size(); ++i) {
[command_buffer commit]; gpu_data_out_[i].reset();
#else }
RET_CHECK_FAIL() << "GPU processing not enabled."; }
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
}
#if defined(MEDIAPIPE_EDGE_TPU)
edgetpu_context_.reset();
#endif #endif
} else { return ::mediapipe::OkStatus();
});
}
// Calculator Auxiliary Section
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -496,39 +473,128 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
input_tensor->bytes); input_tensor->bytes);
} }
} }
}
// 2. Run inference.
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
if (use_advanced_gpu_api_) {
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
#endif
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
} }
// 3. Output processed tensors. ::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu(
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu) {
if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
if (use_advanced_gpu_api_) { if (use_advanced_gpu_api_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK(!input_tensors.empty());
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(
tflite_gpu_runner_->BindSSBOToInputTensor(input_tensors[i].id(), i));
}
if (gpu_output_) {
// Allocate new output tensor.
output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i);
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &tensor));
MP_RETURN_IF_ERROR(
tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
}
} else {
// Re-use internal output tensor.
for (int i = 0; i < gpu_data_out_.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
gpu_data_out_[i]->buffer.id(), i));
}
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else if (gpu_input_) {
// Read GPU input into SSBO.
#if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
// Explicit copy input.
gpu_data_in_.resize(input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
}
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GT(input_tensors.size(), 0);
// Explicit copy input with conversion float 32 bits to 16 bits.
gpu_data_in_.resize(input_tensors.size());
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteInferenceCalculatorConvert";
id<MTLComputeCommandEncoder> compute_encoder =
[command_buffer computeCommandEncoder];
[compute_encoder setComputePipelineState:fp32_to_fp16_program_];
for (int i = 0; i < input_tensors.size(); ++i) {
[compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0];
[compute_encoder setBuffer:gpu_data_in_[i]->buffer offset:0 atIndex:1];
constexpr int kWorkgroupSize = 64; // Block size for GPU shader.
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, 1, 1);
const int threadgroups =
NumGroups(gpu_data_in_[i]->elements, kWorkgroupSize);
[compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1)
threadsPerThreadgroup:threads_per_group];
}
[compute_encoder endEncoding];
[command_buffer commit];
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
// Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
output_tensors_cpu->emplace_back(*tensor);
}
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
CalculatorContext* cc,
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
if (use_advanced_gpu_api_) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
if (gpu_output_) {
// Send out pre-allocated tensors.
cc->Outputs() cc->Outputs()
.Tag(kTensorsGpuTag) .Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp()); .Add(output_tensors_gpu.release(), cc->InputTimestamp());
#endif } else {
// Download to CPU for output.
const auto& tensor_indexes = interpreter_->inputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
std::vector<float> gpu_data(tensor->bytes / sizeof(float));
RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
absl::MakeSpan(tensor->data.f, tensor->bytes)));
output_tensors_cpu->emplace_back(*tensor);
}
// Output result tensors (CPU).
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
}
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} else if (gpu_output_) { } else if (gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #if MEDIAPIPE_TFLITE_GL_INFERENCE
// Output result tensors (GPU). // Output result tensors (GPU).
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &output_tensors_gpu]() -> ::mediapipe::Status {
output_tensors_gpu->resize(gpu_data_out_.size()); output_tensors_gpu->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) { for (int i = 0; i < gpu_data_out_.size(); ++i) {
GpuTensor& tensor = output_tensors_gpu->at(i); GpuTensor& tensor = output_tensors_gpu->at(i);
@ -537,12 +603,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
gpu_data_out_[i]->elements, &tensor)); gpu_data_out_[i]->elements, &tensor));
RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor)); RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
} }
return ::mediapipe::OkStatus();
}));
cc->Outputs() cc->Outputs()
.Tag(kTensorsGpuTag) .Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp()); .Add(output_tensors_gpu.release(), cc->InputTimestamp());
#elif defined(MEDIAPIPE_IOS) #elif MEDIAPIPE_TFLITE_METAL_INFERENCE
// Output result tensors (GPU). // Output result tensors (GPU).
output_tensors_gpu->resize(gpu_data_out_.size()); output_tensors_gpu->resize(gpu_data_out_.size());
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
@ -566,68 +630,58 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
cc->Outputs() cc->Outputs()
.Tag(kTensorsGpuTag) .Tag(kTensorsGpuTag)
.Add(output_tensors_gpu.release(), cc->InputTimestamp()); .Add(output_tensors_gpu.release(), cc->InputTimestamp());
#else #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
} else {
// Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs();
for (int i = 0; i < tensor_indexes.size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
output_tensors_cpu->emplace_back(*tensor);
}
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors_cpu.release(), cc->InputTimestamp());
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
if (delegate_) { CalculatorContext* cc) {
if (gpu_inference_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { const auto& model = *model_packet_.Get<TfLiteModelPtr>();
interpreter_ = nullptr; tflite::ops::builtin::BuiltinOpResolver op_resolver;
delegate_ = nullptr; if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
for (int i = 0; i < gpu_data_in_.size(); ++i) { op_resolver = cc->InputSidePackets()
gpu_data_in_[i].reset(); .Tag("CUSTOM_OP_RESOLVER")
} .Get<tflite::ops::builtin::BuiltinOpResolver>();
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
return ::mediapipe::OkStatus();
}));
#elif defined(MEDIAPIPE_IOS)
interpreter_ = nullptr;
delegate_ = nullptr;
for (int i = 0; i < gpu_data_in_.size(); ++i) {
gpu_data_in_[i].reset();
}
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
#endif
} else {
interpreter_ = nullptr;
delegate_ = nullptr;
}
}
#if defined(MEDIAPIPE_EDGE_TPU)
edgetpu_context_.reset();
#endif
return ::mediapipe::OkStatus();
} }
// Calculator Auxiliary Section // Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
// Allocate interpreter memory for cpu output.
if (!gpu_output_) {
interpreter_ = absl::make_unique<tflite::Interpreter>();
const int num_outputs = tflite_gpu_runner_->GetOutputShapes().size();
interpreter_->AddTensors(num_outputs);
std::vector<int> indices(num_outputs);
for (int i = 0; i < num_outputs; ++i) indices[i] = i;
// There is no ResizeOutputTensor(), so we use 'inputs' space instead.
interpreter_->SetInputs(indices);
TfLiteQuantization quant;
quant.type = kTfLiteNoQuantization;
quant.params = nullptr;
for (int i = 0; i < num_outputs; ++i) {
auto shape = tflite_gpu_runner_->GetOutputShapes()[i];
const int tensor_idx = interpreter_->inputs()[i];
interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
{shape.c}, quant);
CHECK(interpreter_->ResizeInputTensor(
tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk);
}
CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
}
::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner() {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
   // Create and bind OpenGL buffers for outputs.
-  // These buffers are created onve and later their ids are jut passed to the
-  // calculator outputs.
+  // The buffers are created once and their ids are passed to calculator outputs
gpu_data_out_.resize(tflite_gpu_runner_->outputs_size()); gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
gpu_data_out_[i] = absl::make_unique<GPUData>(); gpu_data_out_[i] = absl::make_unique<GPUData>();
@ -638,15 +692,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer)); gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
} }
RET_CHECK_CALL(tflite_gpu_runner_->Build()); RET_CHECK_CALL(tflite_gpu_runner_->Build());
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteInferenceCalculator::LoadModel( ::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
CalculatorContext* cc) { CalculatorContext* cc) {
if (use_advanced_gpu_api_) {
// Use InitTFLiteGPURunner for everything.
return ::mediapipe::OkStatus();
}
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>(); const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver; tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets() op_resolver = cc->InputSidePackets()
@ -654,19 +713,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
.Get<tflite::ops::builtin::BuiltinOpResolver>(); .Get<tflite::ops::builtin::BuiltinOpResolver>();
} }
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
if (use_advanced_gpu_api_) {
tflite::gpu::InferenceOptions options;
options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
tflite_gpu_runner_ =
std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
return tflite_gpu_runner_->InitializeWithModel(model, op_resolver);
}
#endif
#if defined(MEDIAPIPE_EDGE_TPU) #if defined(MEDIAPIPE_EDGE_TPU)
interpreter_ = interpreter_ =
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get()); BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
@ -771,7 +817,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
// Configure and create the delegate. // Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1; options.compile_options.precision_loss_allowed = 1;
@ -832,9 +878,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Must call this last. // Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk); kTfLiteOk);
-#endif  // OpenGL
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
const int kHalfSize = 2; // sizeof(half) const int kHalfSize = 2; // sizeof(half)
// Configure and create the delegate. // Configure and create the delegate.
TFLGpuDelegateOptions options; TFLGpuDelegateOptions options;
@ -958,7 +1004,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
"Error initializating output buffer converter"); "Error initializating output buffer converter");
} }
} }
-#endif  // iOS
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }


@ -45,6 +45,8 @@ message TfLiteInferenceCalculatorOptions {
message Gpu { message Gpu {
// Experimental, Android/Linux only. Use TFLite GPU delegate API2 for // Experimental, Android/Linux only. Use TFLite GPU delegate API2 for
// the NN inference. // the NN inference.
// example:
// delegate: { gpu { use_advanced_gpu_api: true } }
optional bool use_advanced_gpu_api = 1 [default = false]; optional bool use_advanced_gpu_api = 1 [default = false];
} }
// Android only. // Android only.
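For reference, the `use_advanced_gpu_api` flag lives under the calculator's `delegate` field: the calculator change above reads it in Open() as `options.delegate().gpu().use_advanced_gpu_api()` and falls back to the default TFLite path (with a warning) when the input tensors are not GPU buffers. A minimal sketch of enabling it from C++; the generated header path and the protobuf `mutable_*`/`set_*` accessors are assumed from standard proto codegen, not shown in this change:

```cpp
// Sketch only: assumes the standard generated proto header for the options
// message shown above.
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"

mediapipe::TfLiteInferenceCalculatorOptions MakeAdvancedGpuApiOptions() {
  mediapipe::TfLiteInferenceCalculatorOptions options;
  // Equivalent to the text-format example in the option comment:
  //   delegate: { gpu { use_advanced_gpu_api: true } }
  options.mutable_delegate()->mutable_gpu()->set_use_advanced_gpu_api(true);
  return options;
}
```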


@ -25,17 +25,18 @@
#include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
 #import <MetalKit/MetalKit.h>
@@ -44,7 +45,7 @@
 #include "mediapipe/gpu/MPPMetalUtil.h"
 #include "mediapipe/gpu/gpu_buffer.h"
 #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
-#endif  // iOS
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace { namespace {
constexpr int kNumInputTensorsWithAnchors = 3; constexpr int kNumInputTensorsWithAnchors = 3;
@ -56,22 +57,17 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";
namespace mediapipe { namespace mediapipe {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
 using ::tflite::gpu::gl::GlShader;
-#endif
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
 typedef ::tflite::gpu::gl::GlProgram GpuProgram;
-#elif defined(MEDIAPIPE_IOS)
-typedef id<MTLBuffer> GpuTensor;
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
 typedef id<MTLComputePipelineState> GpuProgram;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
namespace { namespace {
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
struct GPUData { struct GPUData {
GpuProgram decode_program; GpuProgram decode_program;
GpuProgram score_program; GpuProgram score_program;
@ -81,7 +77,7 @@ struct GPUData {
GpuTensor scored_boxes_buffer; GpuTensor scored_boxes_buffer;
GpuTensor raw_scores_buffer; GpuTensor raw_scores_buffer;
}; };
-#endif
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
std::vector<Anchor>* anchors) { std::vector<Anchor>* anchors) {
@ -181,13 +177,13 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
std::vector<Anchor> anchors_; std::vector<Anchor> anchors_;
bool side_packet_anchors_{}; bool side_packet_anchors_{};
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   mediapipe::GlCalculatorHelper gpu_helper_;
   std::unique_ptr<GPUData> gpu_data_;
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;
   std::unique_ptr<GPUData> gpu_data_;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
bool gpu_input_ = false; bool gpu_input_ = false;
bool anchors_init_ = false; bool anchors_init_ = false;
@ -205,12 +201,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
} }
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
   if (cc->Inputs().HasTag(kTensorsGpuTag)) {
     cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
     use_gpu |= true;
   }
-#endif  // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("DETECTIONS")) { if (cc->Outputs().HasTag("DETECTIONS")) {
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>(); cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
@ -223,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
if (use_gpu) { if (use_gpu) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -239,12 +233,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
if (cc->Inputs().HasTag(kTensorsGpuTag)) { if (cc->Inputs().HasTag(kTensorsGpuTag)) {
gpu_input_ = true; gpu_input_ = true;
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
     RET_CHECK(gpu_helper_);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
} }
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
@ -401,7 +395,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) { CalculatorContext* cc, std::vector<Detection>* output_detections) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>(); cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2); RET_CHECK_GE(input_tensors.size(), 2);
@ -464,7 +458,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>(); cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
@ -546,17 +540,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
#else #else
LOG(ERROR) << "GPU input on non-Android not supported yet."; LOG(ERROR) << "GPU input on non-Android not supported yet.";
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   gpu_data_.reset();
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -705,7 +699,7 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
CalculatorContext* cc) { CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status { -> ::mediapipe::Status {
gpu_data_ = absl::make_unique<GPUData>(); gpu_data_ = absl::make_unique<GPUData>();
@ -918,7 +912,7 @@ void main() {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
gpu_data_ = absl::make_unique<GPUData>(); gpu_data_ = absl::make_unique<GPUData>();
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
@ -1148,7 +1142,7 @@ kernel void scoreKernel(
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
} }
-#endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }


@ -217,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
for (int i = 0; i < output_landmarks.landmark_size(); ++i) { for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
const Landmark& landmark = output_landmarks.landmark(i); const Landmark& landmark = output_landmarks.landmark(i);
NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark(); NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark();
-    norm_landmark->set_x(static_cast<float>(landmark.x()) /
-                         options_.input_image_width());
-    norm_landmark->set_y(static_cast<float>(landmark.y()) /
-                         options_.input_image_height());
-    norm_landmark->set_z(landmark.z() / options_.normalize_z());
+    norm_landmark->set_x(landmark.x() / options_.input_image_width());
+    norm_landmark->set_y(landmark.y() / options_.input_image_height());
+    // Scale Z coordinate as X + allow additional uniform normalization.
+    norm_landmark->set_z(landmark.z() / options_.input_image_width() /
+                         options_.normalize_z());
norm_landmark->set_visibility(landmark.visibility()); norm_landmark->set_visibility(landmark.visibility());
} }
cc->Outputs() cc->Outputs()


@ -29,7 +29,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
required int32 num_landmarks = 1; required int32 num_landmarks = 1;
   // Size of the input image for the model. These options are used only when
-  // normalized landmarks is needed.
+  // normalized landmarks are needed. Z coordinate is scaled as X assuming
+  // a weak perspective projection camera model.
   optional int32 input_image_width = 2;
optional int32 input_image_height = 3; optional int32 input_image_height = 3;
@ -46,6 +47,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
// beforehand. // beforehand.
optional bool flip_horizontally = 6 [default = false]; optional bool flip_horizontally = 6 [default = false];
-  // A value that z values should be divided by.
+  // A value that Z coordinates should be divided by. This option is used only
+  // when normalized landmarks are needed. It is applied in addition to Z
+  // coordinate being re-scaled as X.
   optional float normalize_z = 5 [default = 1.0];
} }
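Read together with the calculator change above, these options define the full normalization: x and y are divided by the input image width and height, while z is divided by the image width (the weak-perspective scaling) and then by `normalize_z`. A small standalone sketch of that arithmetic, using illustrative constants rather than the calculator's real options:

```cpp
#include <cstdio>

// Illustrative values; in the calculator these come from the options above.
constexpr float kInputImageWidth = 256.0f;
constexpr float kInputImageHeight = 256.0f;
constexpr float kNormalizeZ = 1.0f;

int main() {
  // A landmark in pixel units, as produced by the model.
  const float x = 128.0f, y = 64.0f, z = 32.0f;
  // Same math as TfLiteTensorsToLandmarksCalculator after this change.
  const float norm_x = x / kInputImageWidth;                // 0.5
  const float norm_y = y / kInputImageHeight;               // 0.25
  const float norm_z = z / kInputImageWidth / kNormalizeZ;  // 0.125, Z scaled as X
  std::printf("%.3f %.3f %.3f\n", norm_x, norm_y, norm_z);
  return 0;
}
```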


@ -376,6 +376,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":timed_box_list_id_to_label_calculator_cc_proto", ":timed_box_list_id_to_label_calculator_cc_proto",
"@com_google_absl//absl/container:node_hash_map",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",


@ -122,11 +122,13 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase {
NormalizedLandmark* new_landmark = output_landmarks.add_landmark(); NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
const float new_x = (landmark.x() - left) / (1.0f - left_and_right); const float new_x = (landmark.x() - left) / (1.0f - left_and_right);
const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom); const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom);
+    const float new_z =
+        landmark.z() / (1.0f - left_and_right);  // Scale Z coordinate as X.
     new_landmark->set_x(new_x);
     new_landmark->set_y(new_y);
     // Keep z-coord as is.
-    new_landmark->set_z(landmark.z());
+    new_landmark->set_z(new_z);
// Keep visibility as is. // Keep visibility as is.
new_landmark->set_visibility(landmark.visibility()); new_landmark->set_visibility(landmark.visibility());
} }


@ -123,11 +123,12 @@ class LandmarkProjectionCalculator : public CalculatorBase {
new_x = new_x * input_rect.width() + input_rect.x_center(); new_x = new_x * input_rect.width() + input_rect.x_center();
new_y = new_y * input_rect.height() + input_rect.y_center(); new_y = new_y * input_rect.height() + input_rect.y_center();
+    const float new_z =
+        landmark.z() * input_rect.width();  // Scale Z coordinate as X.
     new_landmark->set_x(new_x);
     new_landmark->set_y(new_y);
-    // Keep z-coord as is.
-    new_landmark->set_z(landmark.z());
+    new_landmark->set_z(new_z);
// Keep visibility as is. // Keep visibility as is.
new_landmark->set_visibility(landmark.visibility()); new_landmark->set_visibility(landmark.visibility());
} }


@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h" #include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -53,7 +54,7 @@ class TimedBoxListIdToLabelCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override; ::mediapipe::Status Process(CalculatorContext* cc) override;
private: private:
-  std::unordered_map<int, std::string> label_map_;
+  absl::node_hash_map<int, std::string> label_map_;
}; };
REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator); REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator);


@ -1,4 +1 @@
-MediaPipe Examples
-==================
-This directory contains MediaPipe Android example applications. Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details.
+This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details.


@ -1,7 +0,0 @@
tricorder: {
  options: {
    builder: {
      config: "android_arm64"
    }
  }
}


@ -83,7 +83,7 @@ android_binary(
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
manifest_values = { manifest_values = {
"applicationId": "com.google.mediapipe.apps.objectdetection3d", "applicationId": "com.google.mediapipe.apps.objectdetection3d",
"appName": "Object Detection 3D", "appName": "Objectron",
"mainActivity": ".MainActivity", "mainActivity": ".MainActivity",
"cameraFacingFront": "False", "cameraFacingFront": "False",
"binaryGraphName": "object_detection_3d.binarypb", "binaryGraphName": "object_detection_3d.binarypb",


@ -1,113 +1 @@
+This directory contains MediaPipe example applications for desktop. Please see [Solutions](https://solutions.mediapipe.dev) for details.

**Hello World**
To build the "Hello World" example, use:
```
bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world
```
and then run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hello_world/hello_world
```
**TFlite Object Detection**
To build the object detection demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TensorFlow Object Detection**
To build the object detection demo using a TensorFlow model on desktop, use:
```
export GLOG_logtostderr=1
bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \
--define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Hand Detection**
To build the hand detection demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Hand Tracking**
To build the hand tracking demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
**TFlite Multi-Hand Tracking**
To build the multi-hand tracking demo using a TFLite model on desktop, use:
```
bazel build -c opt mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
```
and run it using:
```
export GLOG_logtostderr=1
bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_tflite \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt \
--input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
```
To change the number of hands to `x` in this application, change:
1. `min_size:x` in `CollectionHasMinSizeCalculatorOptions` in `mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt`.
2. `max_vec_size:x` in `ClipVectorSizeCalculatorOptions` in `mediapipe/examples/dekstop/hand_tracking/subgraphs/multi_hand_detection_cpu.pbtxt`.


@ -62,8 +62,10 @@ cc_library(
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver", "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
], ],
@ -126,17 +128,20 @@ cc_test(
":content_zooming_calculator", ":content_zooming_calculator",
":content_zooming_calculator_cc_proto", ":content_zooming_calculator_cc_proto",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:benchmark", "//mediapipe/framework/port:benchmark",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
], ],
) )


@ -19,16 +19,20 @@
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" #include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
constexpr char kVideoFrame[] = "VIDEO"; constexpr char kVideoFrame[] = "VIDEO";
constexpr char kVideoSize[] = "VIDEO_SIZE"; constexpr char kVideoSize[] = "VIDEO_SIZE";
-constexpr char kDetectionSet[] = "DETECTIONS";
+constexpr char kSalientRegions[] = "SALIENT_REGIONS";
+constexpr char kDetections[] = "DETECTIONS";
 constexpr char kDetectedBorders[] = "BORDERS";
+constexpr char kCropRect[] = "CROP_RECT";
// Field-of-view (degrees) of the camera's x-axis (width). // Field-of-view (degrees) of the camera's x-axis (width).
// TODO: Parameterize FOV based on camera specs. // TODO: Parameterize FOV based on camera specs.
constexpr float kWidthFieldOfView = 60; constexpr float kWidthFieldOfView = 60;
@ -37,12 +41,12 @@ namespace mediapipe {
namespace autoflip { namespace autoflip {
 // Content zooming calculator zooms in on content when a detection has
-// "only_required" set true. It does this by computing the value of top/bottom
-// borders to remove from the output and sends these to the
-// SceneCroppingCalculator. When more than one detections are received the zoom
-// box is calculated as the union of the detections. Typical applications
-// include mobile makeover and autofliplive face reframing. Currently only
-// supports y-dimension zooming.
+// "only_required" set true or any raw detection input. It does this by
+// computing the value of top/bottom borders to remove from the output and sends
+// these to the SceneCroppingCalculator using BORDERS output or a full rect crop
+// using CROP_RECT output. When more than one detections are received the
+// zoom box is calculated as the union of the detections. Typical applications
+// include mobile makeover and autofliplive face reframing.
class ContentZoomingCalculator : public CalculatorBase { class ContentZoomingCalculator : public CalculatorBase {
public: public:
ContentZoomingCalculator() ContentZoomingCalculator()
@ -56,26 +60,32 @@ class ContentZoomingCalculator : public CalculatorBase {
::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; ::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;
private: private:
-  // Converts bounds to tilt offset and height.
-  ::mediapipe::Status ConvertToTiltZoom(float xmin, float xmax, float ymin,
-                                        float ymax, int* tilt_offset,
-                                        int* height);
+  // Converts bounds to tilt offset, pan offset and height.
+  ::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
+                                           float ymax, int* tilt_offset,
+                                           int* pan_offset, int* height);
ContentZoomingCalculatorOptions options_; ContentZoomingCalculatorOptions options_;
// Detection frame width/height. // Detection frame width/height.
int frame_height_; int frame_height_;
int frame_width_; int frame_width_;
// Path solver used to smooth top/bottom border crop values. // Path solver used to smooth top/bottom border crop values.
std::unique_ptr<KinematicPathSolver> path_solver_height_; std::unique_ptr<KinematicPathSolver> path_solver_height_;
std::unique_ptr<KinematicPathSolver> path_solver_width_;
std::unique_ptr<KinematicPathSolver> path_solver_offset_; std::unique_ptr<KinematicPathSolver> path_solver_offset_;
// Are parameters initialized. // Are parameters initialized.
bool initialized_; bool initialized_;
// Stores the time of the last "only_required" input. // Stores the time of the last "only_required" input.
int64 last_only_required_detection_; int64 last_only_required_detection_;
-  // Border values of last message with detection.
+  // Rect values of last message with detection(s).
   int last_measured_height_;
+  int last_measured_x_offset_;
   int last_measured_y_offset_;
-  // Min border values.
-  float min_height_value_;
+  // Target aspect ratio.
+  float target_aspect_;
+  // Max size of bounding box. If input/output aspect ratios are the same,
+  // will be 1.0. Else, will be less than 1.0 to prevent exceeding the size of
+  // the image in either dimension.
+  float max_frame_value_;
}; };
REGISTER_CALCULATOR(ContentZoomingCalculator); REGISTER_CALCULATOR(ContentZoomingCalculator);
@ -92,8 +102,18 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Input VIDEO or VIDEO_SIZE must be provided."; << "Input VIDEO or VIDEO_SIZE must be provided.";
} }
-  cc->Inputs().Tag(kDetectionSet).Set<DetectionSet>();
-  cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
+  if (cc->Inputs().HasTag(kSalientRegions)) {
+    cc->Inputs().Tag(kSalientRegions).Set<DetectionSet>();
+  }
+  if (cc->Inputs().HasTag(kDetections)) {
+    cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
+  }
+  if (cc->Outputs().HasTag(kDetectedBorders)) {
+    cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
+  }
+  if (cc->Outputs().HasTag(kCropRect)) {
+    cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
+  }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -108,29 +128,38 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
if (options_.has_min_motion_to_reframe()) { if (options_.has_min_motion_to_reframe()) {
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Deprecated min_motion_to_reframe was set, please set " << "Deprecated min_motion_to_reframe was set, please set "
"in kinematic_options_zoom and kinematic_options_tilt directly."; "in kinematic_options_zoom and kinematic_options_tilt "
"directly.";
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status ContentZoomingCalculator::ConvertToTiltZoom( ::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
float xmin, float xmax, float ymin, float ymax, int* tilt_offset, float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
int* height) { int* pan_offset, int* height) {
// Find center of the y-axis offset (for tilt control). // Find center of the y-axis offset (for tilt control).
float y_center = ymin + (ymax - ymin) / 2; float y_center = ymin + (ymax - ymin) / 2;
// Find center of the x-axis offset (for pan control).
float x_center = xmin + (xmax - xmin) / 2;
// Find size and apply scale factor to y-axis. // Find size and apply scale factor to y-axis.
float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin); float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
// Apply min zoom for cases where the target size is wider than input frame // Apply max frame for cases where the target size is different than input
// size. // frame size.
fit_size = fmin(min_height_value_, fit_size); fit_size = fmin(max_frame_value_, fit_size);
// Prevent box from extending beyond the image. // Prevent box from extending beyond the image.
if (y_center - fit_size / 2 < 0) { if (y_center - fit_size / 2 < 0) {
y_center = fit_size / 2; y_center = fit_size / 2;
} else if (y_center + fit_size / 2 > 1) { } else if (y_center + fit_size / 2 > 1) {
y_center = 1 - fit_size / 2; y_center = 1 - fit_size / 2;
} }
if (x_center - fit_size / 2 < 0) {
x_center = fit_size / 2;
} else if (x_center + fit_size / 2 > 1) {
x_center = 1 - fit_size / 2;
}
// Scale to pixel coordinates. // Scale to pixel coordinates.
*tilt_offset = frame_height_ * y_center; *tilt_offset = frame_height_ * y_center;
*pan_offset = frame_width_ * x_center;
*height = frame_height_ * fit_size; *height = frame_height_ * fit_size;
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
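To make the geometry above concrete, the following standalone sketch runs the same pan/tilt/zoom computation for a single box on a 640x480 frame, with a scale factor of 1 and a max frame value of 1; the names are illustrative, not the calculator's members:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Relative box in the middle of the frame: x in [0.4, 0.6], y in [0.1, 0.3].
  const float xmin = 0.4f, xmax = 0.6f, ymin = 0.1f, ymax = 0.3f;
  const float scale_factor = 1.0f, max_frame_value = 1.0f;
  const int frame_width = 640, frame_height = 480;

  float y_center = ymin + (ymax - ymin) / 2;  // 0.2 (tilt)
  float x_center = xmin + (xmax - xmin) / 2;  // 0.5 (pan)
  float fit_size = std::max((ymax - ymin) / scale_factor, xmax - xmin);  // 0.2
  fit_size = std::min(max_frame_value, fit_size);
  // Keep the crop box inside the frame (equivalent to the if/else clamps
  // above when fit_size <= 1).
  y_center = std::min(std::max(y_center, fit_size / 2), 1 - fit_size / 2);
  x_center = std::min(std::max(x_center, fit_size / 2), 1 - fit_size / 2);
  // Pixel-space outputs, as in ConvertToPanTiltZoom().
  const int tilt_offset = static_cast<int>(frame_height * y_center);  // 96
  const int pan_offset = static_cast<int>(frame_width * x_center);    // 320
  const int height = static_cast<int>(frame_height * fit_size);       // 96
  std::printf("tilt=%d pan=%d height=%d\n", tilt_offset, pan_offset, height);
  return 0;
}
```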
@ -151,6 +180,20 @@ namespace {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection,
float* xmin, float* xmax, float* ymin,
float* ymax) {
RET_CHECK(detection.location_data().format() ==
mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
<< "Face detection input is lacking required relative_bounding_box()";
const auto& location = detection.location_data().relative_bounding_box();
*xmin = fmin(*xmin, location.xmin());
*xmax = fmax(*xmax, location.xmin() + location.width());
*ymin = fmin(*ymin, location.ymin());
*ymax = fmax(*ymax, location.ymin() + location.height());
return ::mediapipe::OkStatus();
}
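Because the box starts inverted (xmin = 1, xmax = 0, and likewise for y), each call to this helper can only widen the bounds, which is how the union over all detections described in the class comment falls out. A simplified standalone sketch of the same logic, with a stand-in struct instead of `mediapipe::Detection`:

```cpp
#include <algorithm>
#include <vector>

// Stand-in for the relative bounding box read from
// detection.location_data().relative_bounding_box() in UpdateRanges() above.
struct RelativeBox {
  float xmin, ymin, width, height;
};

// Expands [xmin, xmax] x [ymin, ymax] so it covers every box, mirroring how
// ContentZoomingCalculator unions its input detections.
void UnionOfBoxes(const std::vector<RelativeBox>& boxes, float* xmin,
                  float* xmax, float* ymin, float* ymax) {
  *xmin = 1.0f; *xmax = 0.0f; *ymin = 1.0f; *ymax = 0.0f;
  for (const auto& box : boxes) {
    *xmin = std::min(*xmin, box.xmin);
    *xmax = std::max(*xmax, box.xmin + box.width);
    *ymin = std::min(*ymin, box.ymin);
    *ymax = std::max(*ymax, box.ymin + box.height);
  }
}
```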
void MakeStaticFeatures(const int top_border, const int bottom_border, void MakeStaticFeatures(const int top_border, const int bottom_border,
const int frame_width, const int frame_height, const int frame_width, const int frame_height,
StaticFeatures* static_feature) { StaticFeatures* static_feature) {
@ -173,10 +216,8 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
::mediapipe::Status ContentZoomingCalculator::Process( ::mediapipe::Status ContentZoomingCalculator::Process(
mediapipe::CalculatorContext* cc) { mediapipe::CalculatorContext* cc) {
if (cc->Inputs().HasTag(kVideoFrame)) { if (cc->Inputs().HasTag(kVideoFrame)) {
cv::Mat frame = mediapipe::formats::MatView( frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
&cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>()); frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
frame_width_ = frame.cols;
frame_height_ = frame.rows;
} else if (cc->Inputs().HasTag(kVideoSize)) { } else if (cc->Inputs().HasTag(kVideoSize)) {
frame_width_ = frame_width_ =
cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first; cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
@ -191,10 +232,14 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
path_solver_height_ = std::make_unique<KinematicPathSolver>( path_solver_height_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_zoom(), 0, frame_height_, options_.kinematic_options_zoom(), 0, frame_height_,
static_cast<float>(frame_width_) / kWidthFieldOfView); static_cast<float>(frame_width_) / kWidthFieldOfView);
path_solver_width_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_pan(), 0, frame_width_,
static_cast<float>(frame_width_) / kWidthFieldOfView);
path_solver_offset_ = std::make_unique<KinematicPathSolver>( path_solver_offset_ = std::make_unique<KinematicPathSolver>(
options_.kinematic_options_tilt(), 0, frame_height_, options_.kinematic_options_tilt(), 0, frame_height_,
static_cast<float>(frame_width_) / kWidthFieldOfView); static_cast<float>(frame_width_) / kWidthFieldOfView);
min_height_value_ = 1.0; max_frame_value_ = 1.0;
target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
// If target size is set and wider than input aspect, make sure to always // If target size is set and wider than input aspect, make sure to always
// crop the min required amount. // crop the min required amount.
if (options_.has_target_size()) { if (options_.has_target_size()) {
@ -203,21 +248,23 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
RET_CHECK_GT(options_.target_size().height(), 0) RET_CHECK_GT(options_.target_size().height(), 0)
<< "Provided target height not valid."; << "Provided target height not valid.";
float input_aspect = frame_width_ / static_cast<float>(frame_height_); float input_aspect = frame_width_ / static_cast<float>(frame_height_);
float target_aspect = options_.target_size().width() / target_aspect_ = options_.target_size().width() /
static_cast<float>(options_.target_size().height()); static_cast<float>(options_.target_size().height());
min_height_value_ = max_frame_value_ = std::min(input_aspect / target_aspect_,
(input_aspect < target_aspect) ? input_aspect / target_aspect : 1.0; target_aspect_ / input_aspect);
} }
last_measured_height_ = min_height_value_ * frame_height_; last_measured_height_ = max_frame_value_ * frame_height_;
last_measured_x_offset_ = target_aspect_ * frame_width_;
last_measured_y_offset_ = frame_width_ / 2; last_measured_y_offset_ = frame_width_ / 2;
initialized_ = true; initialized_ = true;
} }
auto detection_set = cc->Inputs().Tag(kDetectionSet).Get<DetectionSet>();
bool only_required_found = false; bool only_required_found = false;
// Compute the box that contains all "is_required" detections. // Compute the box that contains all "is_required" detections.
float xmin = 1, ymin = 1, xmax = 0, ymax = 0; float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
if (cc->Inputs().HasTag(kSalientRegions)) {
auto detection_set = cc->Inputs().Tag(kSalientRegions).Get<DetectionSet>();
for (const auto& region : detection_set.detections()) { for (const auto& region : detection_set.detections()) {
if (!region.only_required()) { if (!region.only_required()) {
continue; continue;
@ -225,46 +272,64 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
only_required_found = true; only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax)); MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
} }
}
if (cc->Inputs().HasTag(kDetections)) {
auto raw_detections =
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
for (const auto& detection : raw_detections) {
only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges(detection, &xmin, &xmax, &ymin, &ymax));
}
}
// Convert bounds to tilt/zoom and in pixel coordinates. // Convert bounds to tilt/zoom and in pixel coordinates.
int offset, height; int offset_y, height, offset_x;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
ConvertToTiltZoom(xmin, xmax, ymin, ymax, &offset, &height)); &offset_x, &height));
if (only_required_found) { if (only_required_found) {
// A only required detection was found. // A only required detection was found.
last_only_required_detection_ = cc->InputTimestamp().Microseconds(); last_only_required_detection_ = cc->InputTimestamp().Microseconds();
last_measured_height_ = height; last_measured_height_ = height;
last_measured_y_offset_ = offset; last_measured_x_offset_ = offset_x;
last_measured_y_offset_ = offset_y;
} else if (cc->InputTimestamp().Microseconds() - } else if (cc->InputTimestamp().Microseconds() -
last_only_required_detection_ >= last_only_required_detection_ >=
options_.us_before_zoomout()) { options_.us_before_zoomout()) {
// No only_required detections found within salient regions packets arriving // No only_required detections found within salient regions packets
// since us_before_zoomout duration. // arriving since us_before_zoomout duration.
height = min_height_value_ * frame_height_; height = max_frame_value_ * frame_height_;
offset = frame_height_ / 2; offset_x = (target_aspect_ * height) / 2;
offset_y = frame_height_ / 2;
} else { } else {
// No only_required detection found but using last detection due to // No only_required detection found but using last detection due to
// duration_before_zoomout_us setting. // duration_before_zoomout_us setting.
height = last_measured_height_; height = last_measured_height_;
offset = last_measured_y_offset_; offset_x = last_measured_x_offset_;
offset_y = last_measured_y_offset_;
} }
// Compute smoothed camera paths. // Compute smoothed camera paths.
MP_RETURN_IF_ERROR(path_solver_height_->AddObservation( MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
height, cc->InputTimestamp().Microseconds())); height, cc->InputTimestamp().Microseconds()));
MP_RETURN_IF_ERROR(path_solver_width_->AddObservation(
offset_x, cc->InputTimestamp().Microseconds()));
MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation( MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
offset, cc->InputTimestamp().Microseconds())); offset_y, cc->InputTimestamp().Microseconds()));
int path_size; int path_height;
MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_size)); MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height));
int path_offset; int path_offset_x;
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset)); MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x));
int path_offset_y;
MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));
// Convert to top/bottom borders to remove. // Convert to top/bottom borders to remove.
int path_top = path_offset - path_size / 2; int path_top = path_offset_y - path_height / 2;
int path_bottom = frame_height_ - (path_offset + path_size / 2); int path_bottom = frame_height_ - (path_offset_y + path_height / 2);
// Transmit result downstream. // Transmit result downstream to scenecroppingcalculator.
if (cc->Outputs().HasTag(kDetectedBorders)) {
std::unique_ptr<StaticFeatures> features = std::unique_ptr<StaticFeatures> features =
absl::make_unique<StaticFeatures>(); absl::make_unique<StaticFeatures>();
MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_, MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
@ -272,6 +337,18 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
cc->Outputs() cc->Outputs()
.Tag(kDetectedBorders) .Tag(kDetectedBorders)
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); .AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
}
// Transmit downstream to glcroppingcalculator.
if (cc->Outputs().HasTag(kCropRect)) {
auto gpu_rect = absl::make_unique<mediapipe::Rect>();
gpu_rect->set_x_center(path_offset_x);
gpu_rect->set_width(path_height * target_aspect_);
gpu_rect->set_y_center(path_offset_y);
gpu_rect->set_height(path_height);
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
Timestamp(cc->InputTimestamp()));
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
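For readers tracing the new CROP_RECT output: the rect above is a mediapipe::Rect in pixel coordinates of the input frame (center point plus width/height). Below is a minimal sketch, not part of this change, of how a downstream consumer could turn that packet into crop bounds; the helper name and the clamping are illustrative assumptions.

```cpp
#include <algorithm>

#include "mediapipe/framework/formats/rect.pb.h"

// Hypothetical helper (not in MediaPipe): converts the CROP_RECT emitted
// above into integer crop bounds on the source frame, clamped to its size.
struct CropBounds { int x0, y0, x1, y1; };

CropBounds ToCropBounds(const mediapipe::Rect& rect, int frame_width,
                        int frame_height) {
  CropBounds bounds;
  bounds.x0 = std::max(0, rect.x_center() - rect.width() / 2);
  bounds.y0 = std::max(0, rect.y_center() - rect.height() / 2);
  bounds.x1 = std::min(frame_width, rect.x_center() + rect.width() / 2);
  bounds.y1 = std::min(frame_height, rect.y_center() + rect.height() / 2);
  return bounds;
}
```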
View File
@ -32,6 +32,8 @@ message ContentZoomingCalculatorOptions {
optional KinematicOptions kinematic_options_zoom = 6; optional KinematicOptions kinematic_options_zoom = 6;
// Kinematic options for tilt (y-axis reframing.) // Kinematic options for tilt (y-axis reframing.)
optional KinematicOptions kinematic_options_tilt = 7; optional KinematicOptions kinematic_options_tilt = 7;
// Kinematic options for pan (x-axis reframing.)
optional KinematicOptions kinematic_options_pan = 10;
// Duration (in MicroSeconds) before returning to fully zoomed out position // Duration (in MicroSeconds) before returning to fully zoomed out position
// when no "only_required" frames are received. // when no "only_required" frames are received.
optional int64 us_before_zoomout = 9 [default = 1000000]; optional int64 us_before_zoomout = 9 [default = 1000000];
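To make the new pan option concrete, here is a sketch of a node config that enables reframing on all three axes, written in the same embedded-textproto style as the test configs further down; the numeric values are illustrative placeholders, not defaults.

```cpp
// Hypothetical config: field names follow the options message above, values
// are chosen only for illustration.
const char kPanTiltZoomConfig[] = R"(
  calculator: "ContentZoomingCalculator"
  input_stream: "VIDEO_SIZE:size"
  input_stream: "DETECTIONS:detections"
  output_stream: "CROP_RECT:rect"
  options: {
    [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
      kinematic_options_pan { min_motion_to_reframe: 1.2 }
      kinematic_options_tilt { min_motion_to_reframe: 1.2 }
      kinematic_options_zoom { min_motion_to_reframe: 1.2 }
      us_before_zoomout: 1000000
    }
  }
)";
```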
View File
@ -16,10 +16,14 @@
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/benchmark.h" #include "mediapipe/framework/port/benchmark.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -36,14 +40,14 @@ namespace {
const char kConfigA[] = R"( const char kConfigA[] = R"(
calculator: "ContentZoomingCalculator" calculator: "ContentZoomingCalculator"
input_stream: "VIDEO:camera_frames" input_stream: "VIDEO:camera_frames"
input_stream: "DETECTIONS:detection_set" input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders" output_stream: "BORDERS:borders"
)"; )";
const char kConfigB[] = R"( const char kConfigB[] = R"(
calculator: "ContentZoomingCalculator" calculator: "ContentZoomingCalculator"
input_stream: "VIDEO:camera_frames" input_stream: "VIDEO:camera_frames"
input_stream: "DETECTIONS:detection_set" input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders" output_stream: "BORDERS:borders"
options: { options: {
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
@ -58,10 +62,17 @@ const char kConfigB[] = R"(
const char kConfigC[] = R"( const char kConfigC[] = R"(
calculator: "ContentZoomingCalculator" calculator: "ContentZoomingCalculator"
input_stream: "VIDEO_SIZE:size" input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detection_set" input_stream: "SALIENT_REGIONS:detection_set"
output_stream: "BORDERS:borders" output_stream: "BORDERS:borders"
)"; )";
const char kConfigD[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detections"
output_stream: "CROP_RECT:rect"
)";
void CheckBorder(const StaticFeatures& static_features, int width, int height, void CheckBorder(const StaticFeatures& static_features, int width, int height,
int top_border, int bottom_border) { int top_border, int bottom_border) {
ASSERT_EQ(2, static_features.border().size()); ASSERT_EQ(2, static_features.border().size());
@ -80,6 +91,43 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
EXPECT_EQ(Border::BOTTOM, part.relative_position()); EXPECT_EQ(Border::BOTTOM, part.relative_position());
} }
void AddDetection(const cv::Rect_<float>& position, const int64 time,
CalculatorRunner* runner) {
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
mediapipe::Detection detection;
detection.mutable_location_data()->set_format(
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
detection.mutable_location_data()
->mutable_relative_bounding_box()
->set_height(position.height);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
position.width);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
position.x);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
position.y);
detections->push_back(detection);
runner->MutableInputs()
->Tag("DETECTIONS")
.packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
runner->MutableInputs()
->Tag("VIDEO_SIZE")
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
}
void CheckCropRect(const int x_center, const int y_center, const int width,
const int height, const int frame_number,
const std::vector<Packet>& output_packets) {
ASSERT_GT(output_packets.size(), frame_number);
const auto& rect = output_packets[frame_number].Get<mediapipe::Rect>();
EXPECT_EQ(rect.x_center(), x_center);
EXPECT_EQ(rect.y_center(), y_center);
EXPECT_EQ(rect.width(), width);
EXPECT_EQ(rect.height(), height);
}
TEST(ContentZoomingCalculatorTest, ZoomTest) { TEST(ContentZoomingCalculatorTest, ZoomTest) {
auto runner = ::absl::make_unique<CalculatorRunner>( auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA)); ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
@ -98,7 +146,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
Adopt(input_frame.release()).At(Timestamp(0))); Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0))); .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator. // Run the calculator.
@ -111,6 +159,66 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
CheckBorder(static_features, 1000, 1000, 495, 395); CheckBorder(static_features, 1000, 1000, 495, 395);
} }
TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD));
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, PanConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(488, 550, 111, 111, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, TiltConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(450, 588, 111, 111, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, ZoomConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(450, 550, 139, 139, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) { TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
auto runner = ::absl::make_unique<CalculatorRunner>( auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB)); ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
@ -129,7 +237,7 @@ TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
Adopt(input_frame.release()).At(Timestamp(0))); Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0))); .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator. // Run the calculator.
@ -166,7 +274,7 @@ TEST(ContentZoomingCalculatorTest, TwoFacesWide) {
Adopt(input_frame.release()).At(Timestamp(0))); Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0))); .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator. // Run the calculator.
@ -191,7 +299,7 @@ TEST(ContentZoomingCalculatorTest, NoDetectionOnInit) {
Adopt(input_frame.release()).At(Timestamp(0))); Adopt(input_frame.release()).At(Timestamp(0)));
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0))); .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator. // Run the calculator.
@ -223,7 +331,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0))); .packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("SALIENT_REGIONS")
.packets.push_back(Adopt(detection_set.release()).At(Timestamp(0))); .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));
// Run the calculator. // Run the calculator.
View File
@ -37,7 +37,7 @@ node {
output_stream: "TENSORS:detection_tensors" output_stream: "TENSORS:detection_tensors"
options: { options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] { [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "face_detection_front.tflite" model_path: "mediapipe/models/face_detection_front.tflite"
} }
} }
} }
@ -118,7 +118,7 @@ node {
output_stream: "labeled_detections" output_stream: "labeled_detections"
options: { options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
label_map_path: "face_detection_front_labelmap.txt" label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
} }
} }
} }
View File
@ -1,18 +1 @@
This directory contains example MediaPipe applications on iOS. This directory contains MediaPipe example applications for iOS. Please see [Solutions](https://solutions.mediapipe.dev) for details.
| Use Case | Directory |
|---------------------------------------|:-----------------------------------:|
| Edge Detection on GPU | edgedetection |
| Face Detection on CPU | facedetectioncpu |
| Face Detection on GPU | facedetectiongpu |
| Object Detection on CPU | objectdetectioncpu |
| Object Detection on GPU | objectdetectiongpu |
| Hand Detection on GPU | handdetectiongpu |
| Hand Tracking on GPU | handtrackinggpu |
For instance, to build an example app for face detection on CPU, run:
```bash
bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
```
(Note: with your own $XCODE_VERSION)
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "edgedetectiongpu",
actual = "EdgeDetectionGpuApp",
)
ios_application( ios_application(
name = "EdgeDetectionGpuApp", name = "EdgeDetectionGpuApp",
bundle_id = "com.google.mediapipe.EdgeDetectionGpu", bundle_id = "com.google.mediapipe.EdgeDetectionGpu",
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "facedetectioncpu",
actual = "FaceDetectionCpuApp",
)
ios_application( ios_application(
name = "FaceDetectionCpuApp", name = "FaceDetectionCpuApp",
bundle_id = "com.google.mediapipe.FaceDetectionCpu", bundle_id = "com.google.mediapipe.FaceDetectionCpu",
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "facedetectiongpu",
actual = "FaceDetectionGpuApp",
)
ios_application( ios_application(
name = "FaceDetectionGpuApp", name = "FaceDetectionGpuApp",
bundle_id = "com.google.mediapipe.FaceDetectionGpu", bundle_id = "com.google.mediapipe.FaceDetectionGpu",
View File
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "10.0"
alias(
name = "facemeshgpu",
actual = "FaceMeshGpuApp",
)
ios_application( ios_application(
name = "FaceMeshGpuApp", name = "FaceMeshGpuApp",
bundle_id = "com.google.mediapipe.FaceMeshGpu", bundle_id = "com.google.mediapipe.FaceMeshGpu",
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "handdetectiongpu",
actual = "HandDetectionGpuApp",
)
ios_application( ios_application(
name = "HandDetectionGpuApp", name = "HandDetectionGpuApp",
bundle_id = "com.google.mediapipe.HandDetectionGpu", bundle_id = "com.google.mediapipe.HandDetectionGpu",
View File
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "10.0"
alias(
name = "handtrackinggpu",
actual = "HandTrackingGpuApp",
)
ios_application( ios_application(
name = "HandTrackingGpuApp", name = "HandTrackingGpuApp",
bundle_id = "com.google.mediapipe.HandTrackingGpu", bundle_id = "com.google.mediapipe.HandTrackingGpu",
View File
@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "10.0"
alias(
name = "multihandtrackinggpu",
actual = "MultiHandTrackingGpuApp",
)
ios_application( ios_application(
name = "MultiHandTrackingGpuApp", name = "MultiHandTrackingGpuApp",
bundle_id = "com.google.mediapipe.MultiHandTrackingGpu", bundle_id = "com.google.mediapipe.MultiHandTrackingGpu",
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "objectdetectioncpu",
actual = "ObjectDetectionCpuApp",
)
ios_application( ios_application(
name = "ObjectDetectionCpuApp", name = "ObjectDetectionCpuApp",
bundle_id = "com.google.mediapipe.ObjectDetectionCpu", bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
View File
@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#import "AppDelegate.h" #import "AppDelegate.h"
#import "ViewController.h"
@interface AppDelegate () @interface AppDelegate ()
@ -22,7 +23,14 @@
- (BOOL)application:(UIApplication *)application - (BOOL)application:(UIApplication *)application
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
// Override point for customization after application launch. ViewController *viewController = (ViewController *)self.window.rootViewController;
NSURL *url = [launchOptions objectForKey:UIApplicationLaunchOptionsURLKey];
// Unattended testing on Firebase is enabled by a custom URL scheme.
if ([url.scheme isEqualToString:@"firebase-game-loop"]) {
[viewController setSourceMode:MediaPipeDemoSourceVideo];
} else {
[viewController setSourceMode:MediaPipeDemoSourceBackCamera];
}
return YES; return YES;
} }
View File
@ -21,6 +21,11 @@ load(
"ios_application", "ios_application",
) )
alias(
name = "objectdetectiongpu",
actual = "ObjectDetectionGpuApp",
)
ios_application( ios_application(
name = "ObjectDetectionGpuApp", name = "ObjectDetectionGpuApp",
bundle_id = "com.google.mediapipe.ObjectDetectionGpu", bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
View File
@ -38,5 +38,18 @@
<array> <array>
<string>UIInterfaceOrientationPortrait</string> <string>UIInterfaceOrientationPortrait</string>
</array> </array>
<key>CFBundleURLTypes</key>
<array>
<dict>
<key>CFBundleURLName</key>
<string>com.google.firebase</string>
<key>CFBundleTypeRole</key>
<string>Editor</string>
<key>CFBundleURLSchemes</key>
<array>
<string>firebase-game-loop</string>
</array>
</dict>
</array>
</dict> </dict>
</plist> </plist>
View File
@ -14,6 +14,11 @@
#import <UIKit/UIKit.h> #import <UIKit/UIKit.h>
@interface ViewController : UIViewController typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) {
MediaPipeDemoSourceBackCamera,
MediaPipeDemoSourceVideo
};
@interface ViewController : UIViewController
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode;
@end @end
View File
@ -17,6 +17,7 @@
#import "mediapipe/objc/MPPGraph.h" #import "mediapipe/objc/MPPGraph.h"
#import "mediapipe/objc/MPPCameraInputSource.h" #import "mediapipe/objc/MPPCameraInputSource.h"
#import "mediapipe/objc/MPPLayerRenderer.h" #import "mediapipe/objc/MPPLayerRenderer.h"
#import "mediapipe/objc/MPPPlayerInputSource.h"
static NSString* const kGraphName = @"mobile_gpu"; static NSString* const kGraphName = @"mobile_gpu";
@ -35,6 +36,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
@implementation ViewController { @implementation ViewController {
/// Handles camera access via AVCaptureSession library. /// Handles camera access via AVCaptureSession library.
MPPCameraInputSource* _cameraSource; MPPCameraInputSource* _cameraSource;
MPPPlayerInputSource* _videoSource;
MediaPipeDemoSourceMode _sourceMode;
/// Inform the user when camera is unavailable. /// Inform the user when camera is unavailable.
IBOutlet UILabel* _noCameraLabel; IBOutlet UILabel* _noCameraLabel;
@ -47,6 +50,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
dispatch_queue_t _videoQueue; dispatch_queue_t _videoQueue;
} }
- (void)setSourceMode:(MediaPipeDemoSourceMode)mode {
_sourceMode = mode;
}
#pragma mark - Cleanup methods #pragma mark - Cleanup methods
- (void)dealloc { - (void)dealloc {
@ -97,13 +104,6 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0);
_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute);
_cameraSource = [[MPPCameraInputSource alloc] init];
[_cameraSource setDelegate:self queue:_videoQueue];
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
// The frame's native format is rotated with respect to the portrait orientation.
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
self.mediapipeGraph.delegate = self; self.mediapipeGraph.delegate = self;
// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
@ -119,27 +119,43 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
- (void)viewWillAppear:(BOOL)animated { - (void)viewWillAppear:(BOOL)animated {
[super viewWillAppear:animated]; [super viewWillAppear:animated];
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
if (granted) {
[self startGraphAndCamera];
dispatch_async(dispatch_get_main_queue(), ^{
_noCameraLabel.hidden = YES;
});
}
}];
}
- (void)startGraphAndCamera {
// Start running self.mediapipeGraph. // Start running self.mediapipeGraph.
NSError* error; NSError* error;
if (![self.mediapipeGraph startWithError:&error]) { if (![self.mediapipeGraph startWithError:&error]) {
NSLog(@"Failed to start graph: %@", error); NSLog(@"Failed to start graph: %@", error);
} }
// Start fetching frames from the camera. switch (_sourceMode) {
case MediaPipeDemoSourceVideo: {
AVAsset* video =
[AVAsset assetWithURL:[[NSBundle mainBundle] URLForResource:@"object_detection"
withExtension:@"mov"]];
_videoSource = [[MPPPlayerInputSource alloc] initWithAVAsset:video];
[_videoSource setDelegate:self queue:_videoQueue];
dispatch_async(_videoQueue, ^{
[_videoSource start];
});
break;
}
case MediaPipeDemoSourceBackCamera:
_cameraSource = [[MPPCameraInputSource alloc] init];
[_cameraSource setDelegate:self queue:_videoQueue];
_cameraSource.sessionPreset = AVCaptureSessionPresetHigh;
_cameraSource.cameraPosition = AVCaptureDevicePositionBack;
// The frame's native format is rotated with respect to the portrait orientation.
_cameraSource.orientation = AVCaptureVideoOrientationPortrait;
[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
if (granted) {
dispatch_async(_videoQueue, ^{ dispatch_async(_videoQueue, ^{
[_cameraSource start]; [_cameraSource start];
}); });
dispatch_async(dispatch_get_main_queue(), ^{
_noCameraLabel.hidden = YES;
});
}
}];
break;
}
} }
#pragma mark - MPPGraphDelegate methods #pragma mark - MPPGraphDelegate methods
@ -164,7 +180,7 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";
- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer - (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
timestamp:(CMTime)timestamp timestamp:(CMTime)timestamp
fromSource:(MPPInputSource*)source { fromSource:(MPPInputSource*)source {
if (source != _cameraSource) { if (source != _cameraSource && source != _videoSource) {
NSLog(@"Unknown source: %@", source); NSLog(@"Unknown source: %@", source);
return; return;
} }
View File
@ -36,7 +36,7 @@ exports_files([
mediapipe_proto_library( mediapipe_proto_library(
name = "calculator_proto", name = "calculator_proto",
srcs = ["calculator.proto"], srcs = ["calculator.proto"],
visibility = [":mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:mediapipe_options_proto", "//mediapipe/framework:mediapipe_options_proto",
@ -68,7 +68,7 @@ mediapipe_proto_library(
mediapipe_proto_library( mediapipe_proto_library(
name = "calculator_profile_proto", name = "calculator_profile_proto",
srcs = ["calculator_profile.proto"], srcs = ["calculator_profile.proto"],
visibility = [":mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
@ -830,6 +830,8 @@ cc_library(
":port", ":port",
":timestamp", ":timestamp",
":type_map", ":type_map",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
@ -1524,6 +1526,21 @@ cc_test(
], ],
) )
cc_test(
name = "packet_registration_test",
size = "small",
srcs = ["packet_registration_test.cc"],
deps = [
":calculator_framework",
":packet",
":packet_test_cc_proto",
":type_map",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)
cc_test( cc_test(
name = "packet_generator_test", name = "packet_generator_test",
size = "small", size = "small",
View File
@ -115,6 +115,9 @@ class CalculatorContract {
// When true, Process is called for every new timestamp bound, with or without // When true, Process is called for every new timestamp bound, with or without
// new packets. A call to Process with only an input timestamp bound is // new packets. A call to Process with only an input timestamp bound is
// normally used to compute a new output timestamp bound. // normally used to compute a new output timestamp bound.
// NOTE: Also, when true, Process is called when input streams become done,
// which means Process needs to handle input streams in "done" state.
// (Usually, by closing calculators' outputs where and when appropriate.)
void SetProcessTimestampBounds(bool process_timestamps) { void SetProcessTimestampBounds(bool process_timestamps) {
process_timestamps_ = process_timestamps; process_timestamps_ = process_timestamps;
} }
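For calculator authors, a minimal sketch of what the note above implies. It assumes InputStream::IsDone() reports the done state inside Process() and uses illustrative "IN"/"OUT" tags; it is not code from this change.

```cpp
class PassThroughBoundsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("IN").Set<int>();
    cc->Outputs().Tag("OUT").Set<int>();
    // Request Process() calls for every new timestamp bound, including the
    // input-stream "done" state described above.
    cc->SetProcessTimestampBounds(true);
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (cc->Inputs().Tag("IN").IsDone()) {
      // The input stream is done: close the output so downstream finishes.
      cc->Outputs().Tag("OUT").Close();
      return ::mediapipe::OkStatus();
    }
    if (cc->Inputs().Tag("IN").IsEmpty()) {
      // Only a timestamp bound arrived: propagate it without emitting data.
      cc->Outputs().Tag("OUT").SetNextTimestampBound(
          cc->InputTimestamp().NextAllowedInStream());
      return ::mediapipe::OkStatus();
    }
    cc->Outputs().Tag("OUT").AddPacket(cc->Inputs().Tag("IN").Value());
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughBoundsCalculator);
```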
View File
@ -91,6 +91,9 @@ typedef ::mediapipe::StatusOr<OutputStreamPoller> StatusOrPoller;
// {{"video_id", mediapipe::MakePacket<std::string>("Ex-uGhDzue4")}})); // {{"video_id", mediapipe::MakePacket<std::string>("Ex-uGhDzue4")}}));
// // See mediapipe/framework/graph_runner.h for an interface // // See mediapipe/framework/graph_runner.h for an interface
// // to insert and extract packets from a graph as it runs. // // to insert and extract packets from a graph as it runs.
// // Once it is done using the graph, close its streams and wait till done.
// MP_RETURN_IF_ERROR(graph->CloseAllInputStreams());
// MP_RETURN_IF_ERROR(graph->WaitUntilDone());
class CalculatorGraph { class CalculatorGraph {
public: public:
// Defines possible modes for adding a packet to a graph input stream. // Defines possible modes for adding a packet to a graph input stream.
@ -157,8 +160,9 @@ class CalculatorGraph {
std::function<::mediapipe::Status(const Packet&)> packet_callback); std::function<::mediapipe::Status(const Packet&)> packet_callback);
// Adds an OutputStreamPoller for a stream. This provides a synchronous, // Adds an OutputStreamPoller for a stream. This provides a synchronous,
// polling API for accessing a stream's output. For asynchronous output, use // polling API for accessing a stream's output. Should only be called before
// ObserveOutputStream. See also the helpers in tool/sink.h. // Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
// also the helpers in tool/sink.h.
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name); StatusOrPoller AddOutputStreamPoller(const std::string& stream_name);
// Gets output side packet by name after the graph is done. However, base // Gets output side packet by name after the graph is done. However, base
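A sketch of the ordering this implies, assuming the mediapipe namespace and illustrative stream names defined by the config; it also follows the CloseAllInputStreams()/WaitUntilDone() pattern from the class comment above.

```cpp
::mediapipe::Status RunGraphWithPoller(const CalculatorGraphConfig& config) {
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  // The poller must be created before StartRun(), per the comment above.
  ASSIGN_OR_RETURN(OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("output_stream"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "input_stream", MakePacket<int>(42).At(Timestamp(0))));
  // Close the inputs so the graph can drain, then poll synchronously.
  MP_RETURN_IF_ERROR(graph.CloseAllInputStreams());
  Packet packet;
  while (poller.Next(&packet)) {
    LOG(INFO) << "Polled value: " << packet.Get<int>();
  }
  return graph.WaitUntilDone();
}
```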
@ -300,6 +304,13 @@ class CalculatorGraph {
void RecordError(const ::mediapipe::Status& error) void RecordError(const ::mediapipe::Status& error)
ABSL_LOCKS_EXCLUDED(error_mutex_); ABSL_LOCKS_EXCLUDED(error_mutex_);
// Combines errors into a status. Returns true if the vector of errors is
// non-empty.
bool GetCombinedErrors(const std::string& error_prefix,
::mediapipe::Status* error_status);
// Convenience overload which specifies a default error prefix.
bool GetCombinedErrors(::mediapipe::Status* error_status);
// Returns the maximum input stream queue size. // Returns the maximum input stream queue size.
int GetMaxInputStreamQueueSize(); int GetMaxInputStreamQueueSize();
@ -501,13 +512,6 @@ class CalculatorGraph {
void CleanupAfterRun(::mediapipe::Status* status) void CleanupAfterRun(::mediapipe::Status* status)
ABSL_LOCKS_EXCLUDED(error_mutex_); ABSL_LOCKS_EXCLUDED(error_mutex_);
// Combines errors into a status. Returns true if the vector of errors is
// non-empty.
bool GetCombinedErrors(const std::string& error_prefix,
::mediapipe::Status* error_status);
// Convenience overload which specifies a default error prefix.
bool GetCombinedErrors(::mediapipe::Status* error_status);
// Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one // Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one
// is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN). // is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN).
// current_run_side_packets_ must be set before this function is called. // current_run_side_packets_ must be set before this function is called.
View File
@ -459,7 +459,8 @@ class Vector3
int LargestAbsComponent() const { int LargestAbsComponent() const {
Vector3 temp = Abs(); Vector3 temp = Abs();
return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2 return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2
: temp[1] > temp[2] ? 1 : 2; : temp[1] > temp[2] ? 1
: 2;
} }
// return the index of the smallest, median, largest component of the vector // return the index of the smallest, median, largest component of the vector
View File
@ -155,7 +155,7 @@ class InputStreamHandler {
// max number of invocations that are allowed to be scheduled is reached. // max number of invocations that are allowed to be scheduled is reached.
// Returns true if at least one invocation has been scheduled. // Returns true if at least one invocation has been scheduled.
// The latest minimum timestamp bound of the input streams is returned in // The latest minimum timestamp bound of the input streams is returned in
// *input_bound iff the latest readiness of the node is kNotReady when the // *input_bound if the latest readiness of the node is kNotReady when the
// function returns. During batching, this value will be equal to the // function returns. During batching, this value will be equal to the
// timestamp of the first set of inputs in the batch. In other cases, // timestamp of the first set of inputs in the batch. In other cases,
// Timestamp::Unset() is returned. // Timestamp::Unset() is returned.
View File
@ -66,6 +66,20 @@ class LegacyCalculatorSupport {
}; };
}; };
// We only declare this variable for two specializations of the template because
// it is only meant to be used for these two types.
// Note that, since these variables are members of specific template
// _specializations_, they are not themselves templates, and therefore their
// definitions must be in the .cc file. However, a declaration still needs to be
// included in the header, or some compilers will assume they have no
// definition.
template <>
thread_local CalculatorContext*
LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
template <>
thread_local CalculatorContract*
LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_ #endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_
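A generic illustration of the language rule relied on above, using hypothetical types rather than MediaPipe code: for an explicit specialization, a static data member declared without an initializer in the header is only a declaration, and exactly one .cc file must supply the definition.

```cpp
// widget_support.h (hypothetical)
struct WidgetSupport {
  template <typename T>
  struct Scoped {
    static thread_local T* current_;
  };
};
// Declaration only: no initializer, so no definition is emitted here.
template <>
thread_local int* WidgetSupport::Scoped<int>::current_;

// widget_support.cc (hypothetical)
// The single definition lives in one translation unit, mirroring the .cc
// definitions referenced by the comment above.
template <>
thread_local int* WidgetSupport::Scoped<int>::current_ = nullptr;
```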
View File
@ -51,6 +51,18 @@ const HolderBase* GetHolder(const Packet& packet) {
return packet.holder_.get(); return packet.holder_.get();
} }
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
const std::string& type_name, const std::string& serialized) {
ASSIGN_OR_RETURN(
auto message_holder,
packet_internal::MessageHolderRegistry::CreateByName(type_name));
auto* message =
const_cast<proto_ns::MessageLite*>(message_holder->GetProtoMessageLite());
RET_CHECK_NE(message, nullptr);
RET_CHECK(message->ParseFromString(serialized));
return packet_internal::Create(message_holder.release());
}
} // namespace packet_internal } // namespace packet_internal
Packet Packet::At(class Timestamp timestamp) const& { Packet Packet::At(class Timestamp timestamp) const& {
View File
@ -27,6 +27,8 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/deps/no_destructor.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
@ -51,6 +53,8 @@ Packet Create(HolderBase* holder, Timestamp timestamp);
Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp); Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp);
const HolderBase* GetHolder(const Packet& packet); const HolderBase* GetHolder(const Packet& packet);
const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet); const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet);
::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
const std::string& type_name, const std::string& serialized);
} // namespace packet_internal } // namespace packet_internal
// A generic container class which can hold data of any type. The type of // A generic container class which can hold data of any type. The type of
@ -355,112 +359,11 @@ class HolderBase {
// Downcasts this to Holder<T>. Returns nullptr if deserialization // Downcasts this to Holder<T>. Returns nullptr if deserialization
// failed or if the requested type is not what is stored. // failed or if the requested type is not what is stored.
template <typename T> template <typename T>
inline Holder<T>* As( Holder<T>* As();
typename std::enable_if<
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
!std::is_base_of<proto_ns::Message, T>::value) ||
(std::is_same<proto_ns::MessageLite, T>::value ||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
// For proto Message/MessageLite subclasses.
// When holder data is a concrete proto, the method downcasts this to
// Holder<T> if the requested type is what is stored.
// When holder data is a generic proto Message/MessageLite and a concrete
// proto type T is requested, the method will downcast the HolderBase to
// Holder<T> if the proto data is an instance of T.
template <typename T>
inline Holder<T>* As(
typename std::enable_if<
(std::is_base_of<proto_ns::MessageLite, T>::value ||
std::is_base_of<proto_ns::Message, T>::value) &&
(!std::is_same<proto_ns::MessageLite, T>::value &&
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) {
// Holder data is an instance of subclass type T.
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Holder data is a generic proto Message/MessageLite and a subclass type T
// is requested.
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
// legally downcast to Holder<T>, even though that downcast works in
// practice. Need to propose a better way to do the downcast.
Holder<T>* holder = static_cast<Holder<T>*>(this);
T tmp;
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
<< " vs requested proto type: " << tmp.GetTypeName();
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
return holder;
}
}
// Does not hold a T.
return nullptr;
}
// Same as non-const As() function. // Same as non-const As() function.
template <typename T> template <typename T>
inline const Holder<T>* As( const Holder<T>* As() const;
typename std::enable_if<
(!std::is_base_of<proto_ns::MessageLite, T>::value &&
!std::is_base_of<proto_ns::Message, T>::value) ||
(std::is_same<proto_ns::MessageLite, T>::value ||
std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
// For proto Message/MessageLite subclasses.
// When holder data is a concrete proto, the method downcasts this to
// Holder<T> if the requested type is what is stored.
// When holder data is a generic proto Message/MessageLite and a concrete
// proto type T is requested, the method will downcast the HolderBase to
// Holder<T> if the proto data is an instance of T.
template <typename T>
inline const Holder<T>* As(
typename std::enable_if<
(std::is_base_of<proto_ns::MessageLite, T>::value ||
std::is_base_of<proto_ns::Message, T>::value) &&
(!std::is_same<proto_ns::MessageLite, T>::value &&
!std::is_same<proto_ns::Message, T>::value)>::type* = 0) const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Holder data is a generic proto Message/MessageLite and a subclass type T
// is requested.
if (HolderIsOfType<Holder<proto_ns::Message>>() ||
HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
// TODO: Holder<proto_ns::Message/MessageLite> cannot be
// legally downcast to Holder<T>, even though that downcast works in
// practice. Need to propose a better way to do the downcast.
Holder<T>* holder = static_cast<const Holder<T>*>(this);
T tmp;
VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
<< " vs requested proto type: " << tmp.GetTypeName();
if (tmp.GetTypeName() == holder->data().GetTypeName()) {
return holder;
}
}
// Does not hold a T.
return nullptr;
}
// Returns the pointer to MessageLite type for the data in holder, if // Returns the pointer to MessageLite type for the data in holder, if
// underlying object is protocol buffer type, otherwise, nullptr is returned. // underlying object is protocol buffer type, otherwise, nullptr is returned.
@ -520,12 +423,68 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data,
return result; return result;
} }
// This registry is used to create Holders of the right concrete C++ type given
// a proto type std::string (which is used as the registration key).
class MessageHolderRegistry
: public GlobalFactoryRegistry<std::unique_ptr<HolderBase>> {};
template <typename T>
struct is_concrete_proto_t
: public std::integral_constant<
bool, std::is_base_of<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::Message, T>{}> {};
// Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T>
struct MessageRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;
};
// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), [] { return absl::make_unique<Holder<T>>(new T); }));
// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>
struct HolderSupport {
static void EnsureStaticInit() {}
};
// This template ensures that, for each concrete MessageLite subclass that is
// stored in a Packet, we register a function that allows us to create a
// Holder with the correct payload type from the proto's type name.
template <typename T>
struct HolderSupport<T,
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
// We must use std::remove_cv to ensure we don't try to register Foo twice if
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
// up to Holder?
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
// For the registration static member to be instantiated, it needs to be
// referenced in a context that requires the definition to exist (see ISO/IEC
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
// We need two different call-sites to cover proto types for which packets
// are only ever created (i.e. the protos are only produced by calculators)
// and proto types for which packets are only ever consumed (i.e. the protos
// are only consumed by calculators).
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
};
template <typename T> template <typename T>
class Holder : public HolderBase { class Holder : public HolderBase {
public: public:
explicit Holder(const T* ptr) : ptr_(ptr) { SetHolderTypeId<Holder>(); } explicit Holder(const T* ptr) : ptr_(ptr) {
HolderSupport<T>::EnsureStaticInit();
SetHolderTypeId<Holder>();
}
~Holder() override { delete_helper(); } ~Holder() override { delete_helper(); }
const T& data() const { return *ptr_; } const T& data() const {
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
size_t GetTypeId() const final { return tool::GetTypeHash<T>(); } size_t GetTypeId() const final { return tool::GetTypeHash<T>(); }
// Releases the underlying data pointer and transfers the ownership to a // Releases the underlying data pointer and transfers the ownership to a
// unique pointer. // unique pointer.
@ -622,6 +581,24 @@ class ForeignHolder : public Holder<T> {
} }
}; };
template <typename T>
Holder<T>* HolderBase::As() {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
template <typename T>
const Holder<T>* HolderBase::As() const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
return static_cast<const Holder<T>*>(this);
}
// Does not hold a T.
return nullptr;
}
} // namespace packet_internal } // namespace packet_internal
inline Packet::Packet(const Packet& packet) inline Packet::Packet(const Packet& packet)
View File
@ -0,0 +1,57 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
namespace test_ns {
class TestSinkCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
cc->Outputs().Tag("OUT").Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
cc->Outputs().Tag("OUT").AddPacket(
MakePacket<int>(x).At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(::mediapipe::test_ns::TestSinkCalculator);
} // namespace test_ns
TEST(PacketTest, InputTypeRegistration) {
using testing::Contains;
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
"mediapipe.InputOnlyProto");
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
Contains("mediapipe.InputOnlyProto"));
}
} // namespace
} // namespace mediapipe
View File
@ -174,54 +174,13 @@ TEST(PacketTest, ReturnGenericProtobufMessage) {
.x(0)); .x(0));
} }
TEST(PacketTest, ReturnProtobufMessageSubType) {
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123);
Packet packet = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, TryWrongProtobufMessageSubType) { TEST(PacketTest, TryWrongProtobufMessageSubType) {
// Packet of PacketTestProto.
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr( std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto); new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123); proto_ptr->add_x(123);
Packet packet = Adopt(proto_ptr.release()); Packet packet = Adopt(proto_ptr.release());
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok()); EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok()); EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
// Packet of proto_ns::Message.
proto_ptr.reset(new ::mediapipe::PacketTestProto);
proto_ptr->add_x(456);
Packet packet2 = Adopt(static_cast<proto_ns::Message*>(proto_ptr.release()));
EXPECT_FALSE(packet2.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet2.ValidateAsType<::mediapipe::PacketTestProto>().ok());
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, ReturnProtobufMessageLiteSubType) {
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
proto_ptr->add_x(123);
Packet packet =
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0));
EXPECT_EQ(123, packet.Get<const ::mediapipe::PacketTestProto>().x(0));
}
TEST(PacketTest, TryWrongProtobufMessageLiteSubType) {
// Packet of PacketTestProto.
std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr(
new ::mediapipe::PacketTestProto);
// Packet of proto_ns::MessageLite.
proto_ptr->add_x(456);
Packet packet =
Adopt(static_cast<proto_ns::MessageLite*>(proto_ptr.release()));
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok());
EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
EXPECT_EQ(456, packet.Get<::mediapipe::PacketTestProto>().x(0));
} }
TEST(PacketTest, GetProtoBase) { TEST(PacketTest, GetProtoBase) {
@ -505,5 +464,26 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) {
EXPECT_TRUE(packet2.IsEmpty()); EXPECT_TRUE(packet2.IsEmpty());
} }
TEST(PacketTest, MessageHolderRegistration) {
using testing::Contains;
Packet packet = MakePacket<mediapipe::SimpleProto>();
ASSERT_EQ(mediapipe::SimpleProto{}.GetTypeName(), "mediapipe.SimpleProto");
EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(),
Contains("mediapipe.SimpleProto"));
}
TEST(PacketTest, PacketFromSerializedProto) {
mediapipe::SimpleProto original;
original.add_value("foo");
std::string serialized = original.SerializeAsString();
StatusOr<Packet> maybe_packet = packet_internal::PacketFromDynamicProto(
"mediapipe.SimpleProto", serialized);
MP_ASSERT_OK(maybe_packet);
Packet packet = maybe_packet.ValueOrDie();
MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>());
EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe
View File
@ -39,3 +39,9 @@ message SerializationProxyProto {
repeated float float_value = 2; repeated float float_value = 2;
repeated string string_value = 3; repeated string string_value = 3;
} }
// This proto should be used only as an input to a calculator, to verify that
// that case is covered.
message InputOnlyProto {
optional int32 x = 1;
}
View File
@ -46,7 +46,7 @@
// but may or may not still be able to run other OpenGL code. // but may or may not still be able to run other OpenGL code.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \ #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
(defined(__APPLE__) || defined(__EMSCRIPTEN__) || \ (defined(__APPLE__) || defined(__EMSCRIPTEN__) || \
defined(MEDIAPIPE_DISABLE_GPU)) defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER)
#define MEDIAPIPE_DISABLE_GL_COMPUTE #define MEDIAPIPE_DISABLE_GL_COMPUTE
#endif #endif
View File
@ -143,8 +143,8 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
{{MakePacket<std::string>("goodbye").At(start_timestamp_)}}); {{MakePacket<std::string>("goodbye").At(start_timestamp_)}});
// Validate the GraphTrace data. // Validate the GraphTrace data.
EXPECT_THAT(GetTrace(), EXPECT_THAT(
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"( GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000 base_time: 1608911100000000
base_timestamp: 1608911100000000 base_timestamp: 1608911100000000
stream_name: "" stream_name: ""
@ -163,7 +163,7 @@ TEST_F(GraphTracerTest, CalculatorTrace) {
stream_id: 1 stream_id: 1
event_data: 1 event_data: 1
} }
output_trace { packet_timestamp: 0 stream_id: 2 } output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
} }
)"))); )")));
} }
@ -205,18 +205,27 @@ TEST_F(GraphTracerTest, GraphTrace) {
LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time, LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time,
{{MakePacket<std::string>("out").At(start_timestamp_)}}); {{MakePacket<std::string>("out").At(start_timestamp_)}});
curr_time += absl::Microseconds(2000); curr_time += absl::Microseconds(2000);
ClearCalculatorContext("PCalculator_3");
LogInputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time, // Note: the packet data ID is based on the packet's payload address, which
// means the same ID can be reused if data is allocated in the same location
// as a previously expired packet (b/160212191). This means the generated
// trace can change depending on the allocator. To keep results stable, we
// must keep the packets used in this test alive until the end. Each
// TestContextBuilder happens to keep a reference to all packets for the last
// context, so for now we just create a separate TestContextBuilder instead of
// clearing it. TODO: revise this test.
SetUpCalculatorContext("PCalculator_3a", /*node_id=*/2, {"up_2"}, {"down_2"});
LogInputPackets("PCalculator_3a", GraphTrace::PROCESS, curr_time,
{MakePacket<std::string>("pup").At(start_timestamp_ + 5)}); {MakePacket<std::string>("pup").At(start_timestamp_ + 5)});
curr_time += absl::Microseconds(20000); curr_time += absl::Microseconds(20000);
LogOutputPackets( LogOutputPackets(
"PCalculator_3", GraphTrace::PROCESS, curr_time, "PCalculator_3a", GraphTrace::PROCESS, curr_time,
{{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}}); {{MakePacket<std::string>("pout").At(start_timestamp_ + 5)}});
curr_time += absl::Microseconds(1000); curr_time += absl::Microseconds(1000);
// Validate the GraphTrace data. // Validate the GraphTrace data.
EXPECT_THAT(GetTrace(), EXPECT_THAT(
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"( GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(R"(
base_time: 1608911100000000 base_time: 1608911100000000
base_timestamp: 1608911100000000 base_timestamp: 1608911100000000
stream_name: "" stream_name: ""
@ -238,9 +247,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
stream_id: 1 stream_id: 1
event_data: 1 event_data: 1
} }
output_trace { packet_timestamp: 0 stream_id: 2 } output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 }
output_trace { packet_timestamp: 0 stream_id: 3 } output_trace { packet_timestamp: 0 stream_id: 3 event_data: 3 }
output_trace { packet_timestamp: 5 stream_id: 3 } output_trace { packet_timestamp: 5 stream_id: 3 event_data: 4 }
} }
calculator_trace { calculator_trace {
node_id: 1 node_id: 1
@ -254,9 +263,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
finish_time: 11000 finish_time: 11000
packet_timestamp: 0 packet_timestamp: 0
stream_id: 2 stream_id: 2
event_data: 2 event_data: 5
} }
output_trace { packet_timestamp: 0 stream_id: 4 } output_trace { packet_timestamp: 0 stream_id: 4 event_data: 6 }
} }
calculator_trace { calculator_trace {
node_id: 2 node_id: 2
@ -270,9 +279,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
finish_time: 16000 finish_time: 16000
packet_timestamp: 0 packet_timestamp: 0
stream_id: 3 stream_id: 3
event_data: 3 event_data: 7
} }
output_trace { packet_timestamp: 0 stream_id: 5 } output_trace { packet_timestamp: 0 stream_id: 5 event_data: 8 }
} }
calculator_trace { calculator_trace {
node_id: 2 node_id: 2
@ -286,9 +295,9 @@ TEST_F(GraphTracerTest, GraphTrace) {
finish_time: 38000 finish_time: 38000
packet_timestamp: 5 packet_timestamp: 5
stream_id: 3 stream_id: 3
event_data: 4 event_data: 9
} }
output_trace { packet_timestamp: 5 stream_id: 5 } output_trace { packet_timestamp: 5 stream_id: 5 event_data: 10 }
} }
)"))); )")));
@ -1275,7 +1284,9 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
GraphTrace trace_1; GraphTrace trace_1;
builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(), builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(),
&trace_1); &trace_1);
EXPECT_THAT(trace_1, EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>( EXPECT_THAT(
trace_1,
EqualsProto(::mediapipe::ParseTextProtoOrDie<GraphTrace>(
R"( R"(
base_time: 1100 base_time: 1100
base_timestamp: 1000 base_timestamp: 1000
@ -1294,7 +1305,7 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) {
stream_id: 1 stream_id: 1
event_data: 0 event_data: 0
} }
output_trace { packet_timestamp: 0 stream_id: 2 } output_trace { packet_timestamp: 0 stream_id: 2 event_data: 0 }
thread_id: 0 thread_id: 0
} }
calculator_trace { calculator_trace {

@ -330,13 +330,12 @@ class TraceBuilder::Impl {
if (trace_event_registry_[event->event_type].is_stream_event()) { if (trace_event_registry_[event->event_type].is_stream_event()) {
auto stream_trace = event->is_finish ? result->add_output_trace() auto stream_trace = event->is_finish ? result->add_output_trace()
: result->add_input_trace(); : result->add_input_trace();
if (event->is_finish) {
// Log only the packet id for each output event.
stream_trace->set_stream_id(stream_id_map_[event->stream_id]);
stream_trace->set_packet_timestamp(LogTimestamp(event->packet_ts));
} else {
// Log the full stream trace for each input event.
BuildStreamTrace(*event, stream_trace); BuildStreamTrace(*event, stream_trace);
if (!event->is_finish) {
// Note: is_finish is true for output events, false for input events.
// For input events, we log some additional timing information. The
// finish_time is the start_time of this Process call, the start_time
// is the finish_time of the Process call that output the packet.
stream_trace->set_finish_time(LogTime(event->event_time)); stream_trace->set_finish_time(LogTime(event->event_time));
const TraceEvent* output_event = FindOutputEvent(*event); const TraceEvent* output_event = FindOutputEvent(*event);
if (output_event) { if (output_event) {

@ -116,10 +116,19 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
CHECK_EQ(stream_ts, Timestamp::Done()); CHECK_EQ(stream_ts, Timestamp::Done());
if (ProcessTimestampBounds()) { if (ProcessTimestampBounds()) {
// With kReadyForClose, the timestamp-bound Done is returned. // With kReadyForClose, the timestamp-bound Done is returned.
// This bound is processed using the preceding input-timestamp.
// TODO: Make all InputStreamHandlers process Done() like this. // TODO: Make all InputStreamHandlers process Done() like this.
ready_timestamps_[i] = stream_ts.PreviousAllowedInStream(); static const Timestamp kDonePrecedingTimestamp =
input_timestamp = std::min(input_timestamp, ready_timestamps_[i]); Timestamp::Done().PreviousAllowedInStream();
if (prev_ts < kDonePrecedingTimestamp) {
// When kReadyForClose is received for the first time for a sync set, it is
// processed using the timestamp preceding Done() to indicate that the input
// stream is done but still needs to be processed.
min_bound = std::min(min_bound, kDonePrecedingTimestamp);
input_timestamp = std::min(input_timestamp, kDonePrecedingTimestamp);
ready_timestamps_[i] = kDonePrecedingTimestamp;
} else {
ready_timestamps_[i] = Timestamp::Done();
}
} else if (prev_ts < Timestamp::Done()) { } else if (prev_ts < Timestamp::Done()) {
stream_became_done = true; stream_became_done = true;
ready_timestamps_[i] = Timestamp::Done(); ready_timestamps_[i] = Timestamp::Done();
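For reference, a minimal sketch of the timestamp arithmetic the new branch relies on; this is an illustration only, assuming the MediaPipe Timestamp API, and the concrete value is consistent with the handler tests further below (which expect InputTimestamp() == Timestamp::Max() before the final Close() invocation at Timestamp::Done()):

  #include "mediapipe/framework/timestamp.h"

  namespace mediapipe {
  // Sketch, not the production code: the timestamp used for the first
  // kReadyForClose invocation of a sync set. Done().PreviousAllowedInStream()
  // evaluates to Timestamp::Max() here, so only the final Close() invocation
  // runs at Timestamp::Done().
  inline Timestamp FirstCloseInvocationTimestamp() {
    return Timestamp::Done().PreviousAllowedInStream();  // == Timestamp::Max()
  }
  }  // namespace mediapipe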

@ -133,6 +133,11 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
} }
} }
const InputStream& Input(const CollectionItemId& id) {
CHECK(cc_);
return cc_->Inputs().Get(id);
}
PacketType packet_type_; PacketType packet_type_;
std::function<void()> headers_ready_callback_; std::function<void()> headers_ready_callback_;
std::function<void()> notification_callback_; std::function<void()> notification_callback_;
@ -262,6 +267,344 @@ TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) {
EXPECT_TRUE(errors_.empty()); EXPECT_TRUE(errors_.empty());
} }
TEST_F(ImmediateInputStreamHandlerTest, ProcessTimestampBounds) {
input_stream_handler_->SetProcessTimestampBounds(true);
Timestamp min_stream_timestamp;
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::PreStream());
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
std::list<Packet> packets;
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
input_stream_handler_->AddPackets(input_b_id, packets);
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(
input_stream_handler_->GetInputStreamManager(input_b_id)->IsEmpty());
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
}
TEST_F(ImmediateInputStreamHandlerTest,
ProcessTimestampBoundsNoOpScheduleInvocations) {
input_stream_handler_->SetProcessTimestampBounds(true);
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
Timestamp min_stream_timestamp;
std::list<Packet> packets;
packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1)));
input_stream_handler_->AddPackets(input_b_id, packets);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}});
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted());
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Try to schedule invocations several times again. Since nothing has changed
// since the last invocation, nothing should be scheduled.
for (int i = 0; i < 3; ++i) {
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp(2));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
EXPECT_TRUE(errors_.empty());
// Try to schedule invocations several times again. Since nothing has changed
// since the last invocation, nothing should be scheduled.
for (int i = 0; i < 3; ++i) {
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
}
// Due to some temporary changes in ImmediateInputStreamHandler, some packets
// - were queued but never released, or
// - were released in the incorrect order.
// As the other test cases kept passing, this test case is designed to catch
// such regressions.
TEST_F(ImmediateInputStreamHandlerTest, VerifyPacketsReleaseOrder) {
input_stream_handler_->SetProcessTimestampBounds(true);
const auto& input_a_id = name_to_id_["input_a"];
const auto& input_b_id = name_to_id_["input_b"];
const auto& input_c_id = name_to_id_["input_c"];
Packet packet_a = Adopt(new std::string("packet a"));
Packet packet_b = Adopt(new std::string("packet b"));
Packet packet_c = Adopt(new std::string("packet c"));
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(1))});
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(2))});
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(3))});
Timestamp min_stream_timestamp;
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1));
ASSERT_FALSE(Input(input_a_id).IsEmpty());
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1));
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(5))});
input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(5))});
input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(5))});
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(2));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
ASSERT_FALSE(Input(input_b_id).IsEmpty());
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(2));
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(3));
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4));
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(4));
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(3));
// FinalizeInputSet() is a no-op.
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Get<std::string>(), "packet a");
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Get<std::string>(), "packet b");
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(5));
ASSERT_FALSE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Get<std::string>(), "packet c");
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(5));
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done());
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
// Schedule invocation for Close.
ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
&cc_->Inputs());
input_stream_handler_->ClearCurrentInputs(cc_);
ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
/*max_allowance=*/1, &min_stream_timestamp));
EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
}
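The three tests above all drive the handler through the same schedule / finalize / clear cycle that the framework performs. A compressed sketch of that loop, distilled from the test code and using the fixture's own names (this is an illustration, not the production driver):

  Timestamp min_stream_timestamp;
  while (input_stream_handler_->ScheduleInvocations(/*max_allowance=*/1,
                                                    &min_stream_timestamp)) {
    // cc_ now holds the input set selected for this invocation; a real node
    // would run the calculator's Process() or Close() against cc_->Inputs().
    input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), &cc_->Inputs());
    input_stream_handler_->ClearCurrentInputs(cc_);
  }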
// This test simulates how CalculatorNode::ProcessNode() uses an input // This test simulates how CalculatorNode::ProcessNode() uses an input
// stream handler and the associated input streams. // stream handler and the associated input streams.
TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) { TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) {

@ -641,4 +641,61 @@ class DummyTestCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(DummyTestCalculator); REGISTER_CALCULATOR(DummyTestCalculator);
// A Calculator that passes the input value to the output after sleeping for
// a set number of microseconds.
class PassThroughWithSleepCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
if (sleep_micros_ < 0) {
return ::mediapipe::InternalError("SLEEP_MICROS should be >= 0");
}
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
clock_->Sleep(absl::Microseconds(sleep_micros_));
int value = cc->Inputs().Index(0).Value().Get<int>();
cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
private:
int sleep_micros_ = 0;
std::shared_ptr<Clock> clock_;
};
REGISTER_CALCULATOR(PassThroughWithSleepCalculator);
// A Calculator that multiplies two input values.
class MultiplyIntCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
RET_CHECK(cc->Outputs().HasTag("OUT"));
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int x = cc->Inputs().Index(0).Value().Get<int>();
int y = cc->Inputs().Index(1).Value().Get<int>();
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(MultiplyIntCalculator);
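A minimal sketch of how the two test calculators above might be wired together; the stream names, side-packet names, and the function name are hypothetical, with the graph config embedded as a text proto the way the surrounding tests do:

  #include <memory>
  #include "mediapipe/framework/calculator_framework.h"
  #include "mediapipe/framework/deps/clock.h"
  #include "mediapipe/framework/deps/monotonic_clock.h"
  #include "mediapipe/framework/port/parse_text_proto.h"

  ::mediapipe::Status RunMultiplyThenSleepGraph() {
    auto config =
        ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>(R"(
          input_stream: "in_a"
          input_stream: "in_b"
          output_stream: "out"
          node {
            calculator: "MultiplyIntCalculator"
            input_stream: "in_a"
            input_stream: "in_b"
            output_stream: "OUT:product"  # the OUT tag is required by GetContract()
          }
          node {
            calculator: "PassThroughWithSleepCalculator"
            input_stream: "product"
            output_stream: "out"
            input_side_packet: "SLEEP_MICROS:sleep_micros"
            input_side_packet: "CLOCK:clock"
          }
        )");
    ::mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    // Supply the sleep duration and clock side packets required by Open().
    std::shared_ptr<::mediapipe::Clock> clock(
        ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
    return graph.StartRun(
        {{"sleep_micros", ::mediapipe::MakePacket<int>(1000)},
         {"clock",
          ::mediapipe::MakePacket<std::shared_ptr<::mediapipe::Clock>>(clock)}});
  }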
} // namespace mediapipe } // namespace mediapipe

@ -101,6 +101,13 @@ std::string ParseNameFromStream(const std::string& stream) {
return name; return name;
} }
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index) {
std::string tag;
int index;
MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
return {tag, index};
}
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) { std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) {
std::string tag, name; std::string tag, name;
int index; int index;

@ -76,6 +76,9 @@ std::string CanonicalNodeName(const CalculatorGraphConfig& graph_config,
// Parses the name from a "tag:index:name". // Parses the name from a "tag:index:name".
std::string ParseNameFromStream(const std::string& stream); std::string ParseNameFromStream(const std::string& stream);
// Parses the TagIndex from a "tag:index".
std::pair<std::string, int> ParseTagIndex(const std::string& tag_index);
// Parses the TagIndex from a "tag:index:name". // Parses the TagIndex from a "tag:index:name".
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream); std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream);
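A quick hedged illustration of the difference between the two helpers; the strings are made up and the behavior assumed here follows MediaPipe's usual "TAG:index[:name]" syntax:

  // Assumed behavior: "VIDEO:1" parses to {"VIDEO", 1}.
  std::pair<std::string, int> tag_index = ParseTagIndex("VIDEO:1");
  // The *FromStream variant also drops the trailing ":name" part of a
  // "tag:index:name" stream spec.
  std::pair<std::string, int> from_stream =
      ParseTagIndexFromStream("VIDEO:1:input_video");  // also {"VIDEO", 1}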

@ -13,15 +13,15 @@
# limitations under the License. # limitations under the License.
# #
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe:__subpackages__"])
load( load(
"//mediapipe/framework/tool:mediapipe_graph.bzl", "//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph", "mediapipe_simple_subgraph",
) )
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup( filegroup(
name = "test_graph", name = "test_graph",
srcs = ["test.pbtxt"], srcs = ["test.pbtxt"],
@ -31,6 +31,8 @@ exports_files([
"test.pbtxt", "test.pbtxt",
"dub_quad_test_subgraph.pbtxt", "dub_quad_test_subgraph.pbtxt",
"nested_test_subgraph.pbtxt", "nested_test_subgraph.pbtxt",
"single_flow_container_test.pbtxt",
"dual_flow_container_test.pbtxt",
]) ])
mediapipe_simple_subgraph( mediapipe_simple_subgraph(

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load("//mediapipe/gpu:metal.bzl", "metal_library") load("//mediapipe/gpu:metal.bzl", "metal_library")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
# Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can # Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can
# interfere with desktop GL. b/73494271 # interfere with desktop GL. b/73494271
config_setting( config_setting(

@ -39,6 +39,7 @@ namespace mediapipe {
// ROTATION: the counterclockwise rotation angle in degrees. This allows // ROTATION: the counterclockwise rotation angle in degrees. This allows
// user to specify different rotation angles for different frames. If this // user to specify different rotation angles for different frames. If this
// stream is provided, it will override the ROTATION input side packet. // stream is provided, it will override the ROTATION input side packet.
// OUTPUT_DIMENSIONS: the output width and height in pixels.
// Additional output streams: // Additional output streams:
// TOP_BOTTOM_PADDING: If use FIT scale mode, this stream outputs the padding // TOP_BOTTOM_PADDING: If use FIT scale mode, this stream outputs the padding
// size of the input image in normalized value [0, 1] for top and bottom // size of the input image in normalized value [0, 1] for top and bottom
@ -103,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator);
if (cc->Inputs().HasTag("ROTATION")) { if (cc->Inputs().HasTag("ROTATION")) {
cc->Inputs().Tag("ROTATION").Set<int>(); cc->Inputs().Tag("ROTATION").Set<int>();
} }
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
}
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
if (cc->InputSidePackets().HasTag("OPTIONS")) { if (cc->InputSidePackets().HasTag("OPTIONS")) {
@ -181,6 +185,18 @@ REGISTER_CALCULATOR(GlScalerCalculator);
} }
::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
// OUTPUT_DIMENSIONS input stream is specified, but value is missing.
return ::mediapipe::OkStatus();
}
const auto& dimensions =
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<DimensionsPacketType>();
dst_width_ = dimensions[0];
dst_height_ = dimensions[1];
}
return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>(); const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>();
QuadRenderer* renderer = nullptr; QuadRenderer* renderer = nullptr;
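A hypothetical node entry showing the new OUTPUT_DIMENSIONS input stream, written as a text-proto fragment embedded in C++ the way graph configs usually are in tests; the stream names are made up:

  // Sketch only: each arriving OUTPUT_DIMENSIONS packet updates the output
  // size used from that point on; when the stream is declared but has no
  // packet at a timestamp, Process() returns early and that frame is skipped.
  constexpr char kScalerNodeSketch[] = R"(
    node {
      calculator: "GlScalerCalculator"
      input_stream: "VIDEO:input_video"
      input_stream: "OUTPUT_DIMENSIONS:output_size"
      output_stream: "VIDEO:scaled_video"
    }
  )";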

@ -140,6 +140,9 @@ node {
num_landmarks: 21 num_landmarks: 21
input_image_width: 256 input_image_width: 256
input_image_height: 256 input_image_height: 256
# The additional scaling factor is used to account for the Z coordinate
# distribution in the training data.
normalize_z: 0.4
} }
} }
} }

@ -144,6 +144,9 @@ node {
num_landmarks: 21 num_landmarks: 21
input_image_width: 256 input_image_width: 256
input_image_height: 256 input_image_height: 256
# The additional scaling factor is used to account for the Z coordinate
# distribution in the training data.
normalize_z: 0.4
} }
} }
} }

@ -25,6 +25,7 @@ android_library(
), ),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/glutil", "//mediapipe/java/com/google/mediapipe/glutil",
"//third_party:androidx_appcompat", "//third_party:androidx_appcompat",

@ -14,17 +14,21 @@
package com.google.mediapipe.components; package com.google.mediapipe.components;
import static java.lang.Math.max;
import android.graphics.SurfaceTexture; import android.graphics.SurfaceTexture;
import android.opengl.GLES11Ext; import android.opengl.GLES11Ext;
import android.opengl.GLES20; import android.opengl.GLES20;
import android.util.Log; import android.util.Log;
import com.google.mediapipe.framework.AppTextureFrame; import com.google.mediapipe.framework.AppTextureFrame;
import com.google.mediapipe.framework.GlSyncToken;
import com.google.mediapipe.glutil.ExternalTextureRenderer; import com.google.mediapipe.glutil.ExternalTextureRenderer;
import com.google.mediapipe.glutil.GlThread; import com.google.mediapipe.glutil.GlThread;
import com.google.mediapipe.glutil.ShaderUtil; import com.google.mediapipe.glutil.ShaderUtil;
import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Queue;
import javax.microedition.khronos.egl.EGLContext; import javax.microedition.khronos.egl.EGLContext;
/** /**
@ -204,8 +208,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
private static final long NANOS_PER_MICRO = 1000; // Nanoseconds in one microsecond. private static final long NANOS_PER_MICRO = 1000; // Nanoseconds in one microsecond.
private volatile SurfaceTexture surfaceTexture = null; private volatile SurfaceTexture surfaceTexture = null;
private final List<TextureFrameConsumer> consumers; private final List<TextureFrameConsumer> consumers;
private List<AppTextureFrame> outputFrames = null;
private int outputFrameIndex = -1; private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>();
private int framesInUse = 0;
private final int framesToKeep;
private ExternalTextureRenderer renderer = null; private ExternalTextureRenderer renderer = null;
private long nextFrameTimestampOffset = 0; private long nextFrameTimestampOffset = 0;
private long timestampOffsetNanos = 0; private long timestampOffsetNanos = 0;
@ -215,10 +222,27 @@ public class ExternalTextureConverter implements TextureFrameProducer {
protected int destinationWidth = 0; protected int destinationWidth = 0;
protected int destinationHeight = 0; protected int destinationHeight = 0;
private class PoolTextureFrame extends AppTextureFrame {
public PoolTextureFrame(int textureName, int width, int height) {
super(textureName, width, height);
}
@Override
public void release(GlSyncToken syncToken) {
super.release(syncToken);
poolFrameReleased(this);
}
@Override
public void release() {
super.release();
poolFrameReleased(this);
}
}
public RenderThread(EGLContext parentContext, int numBuffers) { public RenderThread(EGLContext parentContext, int numBuffers) {
super(parentContext); super(parentContext);
outputFrames = new ArrayList<>(); framesToKeep = numBuffers;
outputFrames.addAll(Collections.nCopies(numBuffers, null));
renderer = new ExternalTextureRenderer(); renderer = new ExternalTextureRenderer();
consumers = new ArrayList<>(); consumers = new ArrayList<>();
} }
@ -283,8 +307,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
@Override @Override
public void releaseGl() { public void releaseGl() {
setSurfaceTexture(null, 0, 0); setSurfaceTexture(null, 0, 0);
for (int i = 0; i < outputFrames.size(); ++i) { while (!framesAvailable.isEmpty()) {
teardownDestination(i); teardownFrame(framesAvailable.remove());
} }
renderer.release(); renderer.release();
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls. super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
@ -337,16 +361,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
} }
} }
private void teardownDestination(int index) { private static void teardownFrame(AppTextureFrame frame) {
if (outputFrames.get(index) != null) { GLES20.glDeleteTextures(1, new int[] {frame.getTextureName()}, 0);
waitUntilReleased(outputFrames.get(index));
GLES20.glDeleteTextures(1, new int[] {outputFrames.get(index).getTextureName()}, 0);
outputFrames.set(index, null);
}
} }
private void setupDestination(int index) { private PoolTextureFrame createFrame() {
teardownDestination(index);
int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight); int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight);
Log.d( Log.d(
TAG, TAG,
@ -354,11 +373,9 @@ public class ExternalTextureConverter implements TextureFrameProducer {
"Created output texture: %d width: %d height: %d", "Created output texture: %d width: %d height: %d",
destinationTextureId, destinationWidth, destinationHeight)); destinationTextureId, destinationWidth, destinationHeight));
bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight); bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight);
outputFrames.set( return new PoolTextureFrame(destinationTextureId, destinationWidth, destinationHeight);
index, new AppTextureFrame(destinationTextureId, destinationWidth, destinationHeight));
} }
/** /**
* Gets next available frame or creates new one if next frame is not initialized * Gets next available frame or creates new one if next frame is not initialized
* or cannot be used with current surface texture. * or cannot be used with current surface texture.
@ -371,20 +388,38 @@ public class ExternalTextureConverter implements TextureFrameProducer {
* NOTE: must be invoked on GL thread * NOTE: must be invoked on GL thread
*/ */
private AppTextureFrame nextOutputFrame() { private AppTextureFrame nextOutputFrame() {
outputFrameIndex = (outputFrameIndex + 1) % outputFrames.size(); PoolTextureFrame outputFrame;
AppTextureFrame outputFrame = outputFrames.get(outputFrameIndex); synchronized (this) {
// Check if the size has changed. outputFrame = framesAvailable.poll();
if (outputFrame == null framesInUse++;
|| outputFrame.getWidth() != destinationWidth
|| outputFrame.getHeight() != destinationHeight) {
// setupDestination will wait for the frame to be released before reallocating it.
setupDestination(outputFrameIndex);
outputFrame = outputFrames.get(outputFrameIndex);
} }
if (outputFrame == null) {
outputFrame = createFrame();
} else if (outputFrame.getWidth() != destinationWidth
|| outputFrame.getHeight() != destinationHeight) {
// Create anew if size has changed.
// TODO: waiting for the consumer sync here may not be necessary.
waitUntilReleased(outputFrame); waitUntilReleased(outputFrame);
teardownFrame(outputFrame);
outputFrame = createFrame();
} else {
// Note: waitUntilReleased does two things: waits for the frame to be released by the CPU,
// and syncs with the GPU sync token provided by the consumer. The first part is redundant
// here (and completes immediately), but the second part is still needed.
waitUntilReleased(outputFrame);
}
return outputFrame; return outputFrame;
} }
protected synchronized void poolFrameReleased(PoolTextureFrame frame) {
framesAvailable.offer(frame);
framesInUse--;
int keep = max(framesToKeep - framesInUse, 0);
while (framesAvailable.size() > keep) {
teardownFrame(framesAvailable.remove());
}
}
/** /**
* Updates output frame with current pixels of surface texture and corresponding timestamp. * Updates output frame with current pixels of surface texture and corresponding timestamp.
* *
@ -417,16 +452,22 @@ public class ExternalTextureConverter implements TextureFrameProducer {
Log.v( Log.v(
TAG, TAG,
String.format( String.format(
"Waiting for tex: %d width: %d height: %d", "Waiting for tex: %d width: %d height: %d timestamp: %d",
frame.getTextureName(), frame.getWidth(), frame.getHeight())); frame.getTextureName(),
frame.getWidth(),
frame.getHeight(),
frame.getTimestamp()));
} }
frame.waitUntilReleased(); frame.waitUntilReleased();
if (Log.isLoggable(TAG, Log.VERBOSE)) { if (Log.isLoggable(TAG, Log.VERBOSE)) {
Log.v( Log.v(
TAG, TAG,
String.format( String.format(
"Finished waiting for tex: %d width: %d height: %d", "Finished waiting for tex: %d width: %d height: %d timestamp: %d",
frame.getTextureName(), frame.getWidth(), frame.getHeight())); frame.getTextureName(),
frame.getWidth(),
frame.getHeight(),
frame.getTimestamp()));
} }
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
// Someone interrupted our thread. This is not supposed to happen: we own // Someone interrupted our thread. This is not supposed to happen: we own

View File

@ -20,6 +20,7 @@ import android.media.AudioFormat;
import android.os.Handler; import android.os.Handler;
import android.util.Log; import android.util.Log;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidAssetUtil; import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.Graph;
@ -32,10 +33,12 @@ import com.google.mediapipe.framework.SurfaceOutput;
import com.google.mediapipe.framework.TextureFrame; import com.google.mediapipe.framework.TextureFrame;
import java.io.File; import java.io.File;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -106,6 +109,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
initializeGraphAndPacketCreator(context, graphName); initializeGraphAndPacketCreator(context, graphName);
} }
/**
* Constructor.
*
* @param graphConfig the proto object representation of the graph.
*/
public FrameProcessor(CalculatorGraphConfig graphConfig) {
initializeGraphAndPacketCreator(graphConfig);
}
/** /**
* Initializes a graph for processing data in real time. * Initializes a graph for processing data in real time.
* *
@ -123,6 +135,17 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
packetCreator = new AndroidPacketCreator(mediapipeGraph); packetCreator = new AndroidPacketCreator(mediapipeGraph);
} }
/**
* Initializes a graph for processing data in real time.
*
* @param graphConfig the proto object representation of the graph.
*/
private void initializeGraphAndPacketCreator(CalculatorGraphConfig graphConfig) {
mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(graphConfig);
packetCreator = new AndroidPacketCreator(mediapipeGraph);
}
/** Callback for errors occurring during processing in the graph. */ /** Callback for errors occurring during processing in the graph. */
public interface ErrorListener { public interface ErrorListener {
void onError(RuntimeException error); void onError(RuntimeException error);
@ -186,6 +209,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
currentConsumers = videoConsumers; currentConsumers = videoConsumers;
} }
for (TextureFrameConsumer consumer : currentConsumers) { for (TextureFrameConsumer consumer : currentConsumers) {
// Note: each consumer will release its TextureFrame, so each gets a separate object
// (though they all reference the same data).
TextureFrame frame = PacketGetter.getTextureFrame(packet); TextureFrame frame = PacketGetter.getTextureFrame(packet);
if (Log.isLoggable(TAG, Log.VERBOSE)) { if (Log.isLoggable(TAG, Log.VERBOSE)) {
Log.v( Log.v(
@ -373,9 +398,10 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
/** /**
* Returns true if the MediaPipe graph can accept one more input frame. * Returns true if the MediaPipe graph can accept one more input frame.
*
* @throws MediaPipeException for any error status. * @throws MediaPipeException for any error status.
*/ */
private boolean maybeAcceptNewFrame() { private boolean maybeAcceptNewFrame(long timestamp) {
if (!started.getAndSet(true)) { if (!started.getAndSet(true)) {
startGraph(); startGraph();
} }
@ -395,7 +421,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
frame.getTextureName(), frame.getWidth(), frame.getHeight())); frame.getTextureName(), frame.getWidth(), frame.getHeight()));
} }
if (!maybeAcceptNewFrame()) { if (!maybeAcceptNewFrame(frame.getTimestamp())) {
return; return;
} }
@ -451,7 +477,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor
public void onNewFrame(final Bitmap bitmap, long timestamp) { public void onNewFrame(final Bitmap bitmap, long timestamp) {
Packet packet = null; Packet packet = null;
try { try {
if (!maybeAcceptNewFrame()) { if (!maybeAcceptNewFrame(timestamp)) {
return; return;
} }

@ -17,8 +17,8 @@ package com.google.mediapipe.components;
import android.Manifest; import android.Manifest;
import android.app.Activity; import android.app.Activity;
import android.content.pm.PackageManager; import android.content.pm.PackageManager;
import androidx.core.app.ActivityCompat;
import android.util.Log; import android.util.Log;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat; import androidx.core.content.ContextCompat;
/** Manages camera permission request and handling. */ /** Manages camera permission request and handling. */

@ -18,6 +18,10 @@ import com.google.mediapipe.framework.TextureFrame;
/** Lightweight abstraction for an object that can receive video frames. */ /** Lightweight abstraction for an object that can receive video frames. */
public interface TextureFrameConsumer { public interface TextureFrameConsumer {
/** Called when a new {@link TextureFrame} is available. */ /**
* Called when a new {@link TextureFrame} is available.
*
* Important: implementations of this method should call frame.release().
**/
public abstract void onNewFrame(TextureFrame frame); public abstract void onNewFrame(TextureFrame frame);
} }

@ -272,6 +272,10 @@ public final class PacketGetter {
* <p>Note: in order for the application to be able to use the texture, its GL context must be * <p>Note: in order for the application to be able to use the texture, its GL context must be
* linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)} * linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)}
* with the native handle to the application's GL context as the second argument. * with the native handle to the application's GL context as the second argument.
*
* <p>The returned GraphTextureFrame must be released by the caller. If this method is called
* multiple times, each returned GraphTextureFrame is an independent reference to the underlying
* texture data, and must be released individually.
*/ */
public static GraphTextureFrame getTextureFrame(final Packet packet) { public static GraphTextureFrame getTextureFrame(final Packet packet) {
return new GraphTextureFrame( return new GraphTextureFrame(

@ -1,7 +0,0 @@
tricorder: {
options: {
builder: {
config: "android_arm"
}
}
}

@ -184,6 +184,20 @@ typedef NS_ENUM(int, MPPPacketType) {
packetType:(MPPPacketType)packetType packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp &)timestamp; timestamp:(const mediapipe::Timestamp &)timestamp;
/// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. Drops frames and
/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES,
/// allows MediaPipe to overwrite the packet contents on successful sending for
/// possibly increased efficiency. Returns YES if the packet was successfully sent.
/// Sets error to a non-nil value if an error occurs in the graph when sending the
/// packet.
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
intoStream:(const std::string &)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp &)timestamp
allowOverwrite:(BOOL)allowOverwrite
error:(NSError **)error;
/// Sends a pixel buffer into a graph input stream, using the specified packet /// Sends a pixel buffer into a graph input stream, using the specified packet
/// type. The graph must have been started before calling this. The timestamp is /// type. The graph must have been started before calling this. The timestamp is
/// automatically incremented from the last timestamp used by this method. Drops /// automatically incremented from the last timestamp used by this method. Drops

@ -327,22 +327,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
packetType:(MPPPacketType)packetType packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp timestamp:(const mediapipe::Timestamp&)timestamp
allowOverwrite:(BOOL)allowOverwrite { allowOverwrite:(BOOL)allowOverwrite {
NSError* error;
bool success = [self sendPixelBuffer:imageBuffer
intoStream:inputName
packetType:packetType
timestamp:timestamp
allowOverwrite:allowOverwrite
error:&error];
if (error) {
_GTMDevLog(@"failed to send packet: %@", error);
}
return success;
}
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
intoStream:(const std::string&)inputName
packetType:(MPPPacketType)packetType
timestamp:(const mediapipe::Timestamp&)timestamp
allowOverwrite:(BOOL)allowOverwrite
error:(NSError**)error {
if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO; if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType]; mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType];
NSError* error;
BOOL success; BOOL success;
if (allowOverwrite) { if (allowOverwrite) {
packet = std::move(packet).At(timestamp); packet = std::move(packet).At(timestamp);
success = [self movePacket:std::move(packet) success = [self movePacket:std::move(packet) intoStream:inputName error:error];
intoStream:inputName
error:&error];
} else { } else {
success = [self sendPacket:packet.At(timestamp) success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
intoStream:inputName
error:&error];
} }
if (success) _framesInFlight++; if (success) _framesInFlight++;
else _GTMDevLog(@"failed to send packet: %@", error);
return success; return success;
} }

@ -423,6 +423,10 @@ tasks and tracking (or class) fields for tracking information.
|`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.| |`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.|
|`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.| |`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.|
|`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.| |`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.|
|`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.|
|`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.|
|`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.|
|`region/3d_point/\*`| *special* |`add_bbox_3d_point` / `AddBBox3dPoint`|Operates on 3d_point/{x,y,z} with a single call.|
|`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.| |`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.|
|`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.| |`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.|
|`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. Distinguishes empty from unannotated frames.| |`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. Distinguishes empty from unannotated frames.|

@ -229,6 +229,18 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
sequence); sequence);
} }
} }
if (Get3dPointSize(prefix, *sequence) > 0) {
std::string x_key = merge_prefix(prefix, kRegion3dPointXKey);
auto* region_feature_list = MutableFeatureList(x_key, sequence);
RET_CHECK_EQ(num_bboxes, region_feature_list->feature_size())
<< "Expected number of BBox timestamps and boxes to match.";
ClearBBoxNumRegions(prefix, sequence);
for (int i = 0; i < num_bboxes; ++i) {
AddBBoxNumRegions(
prefix, region_feature_list->feature(i).float_list().value_size(),
sequence);
}
}
// Collect which timestamps currently match to which indices in timestamps. // Collect which timestamps currently match to which indices in timestamps.
// skip empty timestamps. // skip empty timestamps.
// Requires sorted indices. // Requires sorted indices.
@ -453,6 +465,47 @@ void ClearPoint(const std::string& prefix,
ClearBBoxPointX(prefix, sequence); ClearBBoxPointX(prefix, sequence);
} }
int Get3dPointSize(const std::string& prefix,
const tensorflow::SequenceExample& sequence) {
return GetBBox3dPointXSize(prefix, sequence);
}
std::vector<::std::tuple<float, float, float>> Get3dPointAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index) {
const auto& xs = GetBBox3dPointXAt(prefix, sequence, index);
const auto& ys = GetBBox3dPointYAt(prefix, sequence, index);
const auto& zs = GetBBox3dPointZAt(prefix, sequence, index);
std::vector<::std::tuple<float, float, float>> points(ys.size());
for (int i = 0; i < xs.size(); ++i) {
points[i] = std::make_tuple(xs[i], ys[i], zs[i]);
}
return points;
}
void Add3dPoint(const std::string& prefix,
const std::vector<::std::tuple<float, float, float>>& points,
tensorflow::SequenceExample* sequence) {
::std::vector<float> xs;
::std::vector<float> ys;
::std::vector<float> zs;
for (auto& point : points) {
xs.push_back(std::get<0>(point));
ys.push_back(std::get<1>(point));
zs.push_back(std::get<2>(point));
}
AddBBox3dPointX(prefix, xs, sequence);
AddBBox3dPointY(prefix, ys, sequence);
AddBBox3dPointZ(prefix, zs, sequence);
}
void Clear3dPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence) {
ClearBBox3dPointX(prefix, sequence);
ClearBBox3dPointY(prefix, sequence);
ClearBBox3dPointZ(prefix, sequence);
}
std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt( std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence, const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index) { int index) {

@ -268,6 +268,10 @@ const char kRegionBBoxXMaxKey[] = "region/bbox/xmax";
const char kRegionPointXKey[] = "region/point/x"; const char kRegionPointXKey[] = "region/point/x";
const char kRegionPointYKey[] = "region/point/y"; const char kRegionPointYKey[] = "region/point/y";
const char kRegionRadiusKey[] = "region/radius"; const char kRegionRadiusKey[] = "region/radius";
// The 3d point can denote keypoints.
const char kRegion3dPointXKey[] = "region/3d_point/x";
const char kRegion3dPointYKey[] = "region/3d_point/y";
const char kRegion3dPointZKey[] = "region/3d_point/z";
// The number of regions at that timestep. // The number of regions at that timestep.
const char kRegionNumRegionsKey[] = "region/num_regions"; const char kRegionNumRegionsKey[] = "region/num_regions";
// Whether that timestep is annotated for bounding regions. // Whether that timestep is annotated for bounding regions.
@ -333,6 +337,18 @@ void AddPoint(const std::string& prefix,
void ClearPoint(const std::string& prefix, void ClearPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence); tensorflow::SequenceExample* sequence);
// The input and output format is a tuple of <x, y, z> coordinates. The 3D
// points can denote keypoints.
int Get3dPointSize(const std::string& prefix,
const tensorflow::SequenceExample& sequence);
std::vector<std::tuple<float, float, float>> Get3dPointAt(
const std::string& prefix, const tensorflow::SequenceExample& sequence,
int index);
void Add3dPoint(const std::string& prefix,
const std::vector<std::tuple<float, float, float>>& points,
tensorflow::SequenceExample* sequence);
void Clear3dPoint(const std::string& prefix,
tensorflow::SequenceExample* sequence);
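A hedged sketch of storing and reading one frame of 3D keypoints through the prefixed accessors that the macros below generate (the names follow the add_bbox_3d_point / AddBBox3dPoint entries in the MediaSequence documentation above; the coordinate values are made up):

  #include <tuple>
  #include <vector>
  #include "mediapipe/util/sequence/media_sequence.h"
  #include "tensorflow/core/example/example.pb.h"

  void AddAndReadKeypoints(tensorflow::SequenceExample* sequence) {
    namespace ms = mediapipe::mediasequence;
    // One frame's worth of <x, y, z> keypoints.
    std::vector<std::tuple<float, float, float>> keypoints = {
        {0.25f, 0.50f, 0.10f}, {0.75f, 0.50f, -0.10f}};
    ms::AddBBox3dPoint(keypoints, sequence);  // stores under the "region/" keys
    // Reads the same tuples back for frame 0.
    auto restored = ms::GetBBox3dPointAt(*sequence, 0);
  }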
#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \ #define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \
inline int CONCAT_STR3(Get, identifier, \ inline int CONCAT_STR3(Get, identifier, \
Size)(const tensorflow::SequenceExample& sequence) { \ Size)(const tensorflow::SequenceExample& sequence) { \
@ -388,6 +404,44 @@ void ClearPoint(const std::string& prefix,
inline void CONCAT_STR3(Clear, identifier, Point)( \ inline void CONCAT_STR3(Clear, identifier, Point)( \
std::string name, tensorflow::SequenceExample * sequence) { \ std::string name, tensorflow::SequenceExample * sequence) { \
return ClearPoint(name, sequence); \ return ClearPoint(name, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
const tensorflow::SequenceExample& sequence) { \
return Get3dPointSize(prefix, sequence); \
} \
inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \
const std::string& name, const tensorflow::SequenceExample& sequence) { \
return Get3dPointSize(name, sequence); \
} \
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
Get, identifier, 3dPointAt)(const tensorflow::SequenceExample& sequence, \
int index) { \
return Get3dPointAt(prefix, sequence, index); \
} \
inline std::vector<std::tuple<float, float, float>> CONCAT_STR3( \
Get, identifier, 3dPointAt)(const std::string& name, \
const tensorflow::SequenceExample& sequence, \
int index) { \
return Get3dPointAt(name, sequence, index); \
} \
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
const std::vector<std::tuple<float, float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return Add3dPoint(prefix, points, sequence); \
} \
inline void CONCAT_STR3(Add, identifier, 3dPoint)( \
const std::string& name, \
const std::vector<std::tuple<float, float, float>>& points, \
tensorflow::SequenceExample* sequence) { \
return Add3dPoint(name, points, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, \
3dPoint)(tensorflow::SequenceExample * sequence) { \
return Clear3dPoint(prefix, sequence); \
} \
inline void CONCAT_STR3(Clear, identifier, 3dPoint)( \
std::string name, tensorflow::SequenceExample * sequence) { \
return Clear3dPoint(name, sequence); \
}
#define PREFIXED_BBOX(identifier, prefix) \
@@ -435,6 +489,12 @@ void ClearPoint(const std::string& prefix,
kRegionPointYKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, Radius), \
kRegionRadiusKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointX), \
kRegion3dPointXKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointY), \
kRegion3dPointYKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointZ), \
kRegion3dPointZKey, prefix) \
FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST( \
CONCAT_STR2(identifier, EmbeddingFloats), kRegionEmbeddingFloatKey, \
prefix) \
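For orientation, here is a minimal usage sketch of the new 3D-point accessors (not part of the diff). It uses only AddBBox3dPoint, GetBBox3dPointSize, and GetBBox3dPointAt, which the C++ round-trip test added later in this commit also exercises; the helper name Example3dKeypoints and the mediapipe::mediasequence namespace nesting are assumptions for illustration. Note the coordinate convention: 2D points are stored as <y, x> pairs to match bounding-box order, while 3D points are <x, y, z> tuples.

#include <tuple>
#include <vector>

#include "mediapipe/util/sequence/media_sequence.h"

namespace mediapipe {
namespace mediasequence {

// Store two 3D keypoints for one timestep, then read them back. 3D points
// are <x, y, z> tuples, unlike 2D points, which stay in <y, x> order to
// match bounding-box coordinates.
void Example3dKeypoints(tensorflow::SequenceExample* sequence) {
  std::vector<std::tuple<float, float, float>> keypoints = {
      std::make_tuple(0.3f, 0.5f, 0.1f), std::make_tuple(0.4f, 0.7f, 0.2f)};
  AddBBox3dPoint(keypoints, sequence);

  int num_timesteps = GetBBox3dPointSize(*sequence);  // == 1 after one Add.
  auto restored = GetBBox3dPointAt(*sequence, 0);     // two <x, y, z> tuples.
  (void)num_timesteps;
  (void)restored;
}

}  // namespace mediasequence
}  // namespace mediapipe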
View File
@@ -262,6 +262,10 @@ REGION_BBOX_XMAX_KEY = "region/bbox/xmax"
REGION_POINT_X_KEY = "region/point/x"
REGION_POINT_Y_KEY = "region/point/y"
REGION_RADIUS_KEY = "region/radius"
# The 3D point can denote keypoints.
REGION_3D_POINT_X_KEY = "region/3d_point/x"
REGION_3D_POINT_Y_KEY = "region/3d_point/y"
REGION_3D_POINT_Z_KEY = "region/3d_point/z"
# The number of regions at that timestep.
REGION_NUM_REGIONS_KEY = "region/num_regions"
# Whether that timestep is annotated for regions.
@@ -365,6 +369,15 @@ def _create_region_with_prefix(name, prefix):
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(name + "_point_y", REGION_POINT_Y_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_x", REGION_3D_POINT_X_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_y", REGION_3D_POINT_Y_KEY,
prefix=prefix, module_dict=globals())
msu.create_float_list_feature_list(
name + "_3d_point_z", REGION_3D_POINT_Z_KEY,
prefix=prefix, module_dict=globals())
msu.create_bytes_list_context_feature(name + "_parts",
REGION_PARTS_KEY,
prefix=prefix, module_dict=globals())
@@ -406,6 +419,39 @@ def _create_region_with_prefix(name, prefix):
clear_bbox_xmin(sequence_example, prefix=prefix)
clear_bbox_ymax(sequence_example, prefix=prefix)
clear_bbox_xmax(sequence_example, prefix=prefix)
def get_prefixed_point_at(index, sequence_example, prefix):
return np.stack((
get_bbox_point_y_at(index, sequence_example, prefix=prefix),
get_bbox_point_x_at(index, sequence_example, prefix=prefix)),
1)
def add_prefixed_point(values, sequence_example, prefix):
add_bbox_point_y(values[:, 0], sequence_example, prefix=prefix)
add_bbox_point_x(values[:, 1], sequence_example, prefix=prefix)
def get_prefixed_point_size(sequence_example, prefix):
return get_bbox_point_y_size(sequence_example, prefix=prefix)
def has_prefixed_point(sequence_example, prefix):
return has_bbox_point_y(sequence_example, prefix=prefix)
def clear_prefixed_point(sequence_example, prefix):
clear_bbox_point_y(sequence_example, prefix=prefix)
clear_bbox_point_x(sequence_example, prefix=prefix)
def get_prefixed_3d_point_at(index, sequence_example, prefix):
return np.stack((
get_bbox_3d_point_x_at(index, sequence_example, prefix=prefix),
get_bbox_3d_point_y_at(index, sequence_example, prefix=prefix),
get_bbox_3d_point_z_at(index, sequence_example, prefix=prefix)),
1)
def add_prefixed_3d_point(values, sequence_example, prefix):
add_bbox_3d_point_x(values[:, 0], sequence_example, prefix=prefix)
add_bbox_3d_point_y(values[:, 1], sequence_example, prefix=prefix)
add_bbox_3d_point_z(values[:, 2], sequence_example, prefix=prefix)
def get_prefixed_3d_point_size(sequence_example, prefix):
return get_bbox_3d_point_x_size(sequence_example, prefix=prefix)
def has_prefixed_3d_point(sequence_example, prefix):
return has_bbox_3d_point_x(sequence_example, prefix=prefix)
def clear_prefixed_3d_point(sequence_example, prefix):
clear_bbox_3d_point_x(sequence_example, prefix=prefix)
clear_bbox_3d_point_y(sequence_example, prefix=prefix)
clear_bbox_3d_point_z(sequence_example, prefix=prefix)
# pylint: enable=undefined-variable
msu.add_functions_to_module({
"get_" + name + "_at":
@@ -419,6 +465,30 @@ def _create_region_with_prefix(name, prefix):
"clear_" + name:
functools.partial(clear_prefixed_bbox, prefix=prefix),
}, module_dict=globals())
msu.add_functions_to_module({
"get_" + name + "_point_at":
functools.partial(get_prefixed_point_at, prefix=prefix),
"add_" + name + "_point":
functools.partial(add_prefixed_point, prefix=prefix),
"get_" + name + "_point_size":
functools.partial(get_prefixed_point_size, prefix=prefix),
"has_" + name + "_point":
functools.partial(has_prefixed_point, prefix=prefix),
"clear_" + name + "_point":
functools.partial(clear_prefixed_point, prefix=prefix),
}, module_dict=globals())
msu.add_functions_to_module({
"get_" + name + "_3d_point_at":
functools.partial(get_prefixed_3d_point_at, prefix=prefix),
"add_" + name + "_3d_point":
functools.partial(add_prefixed_3d_point, prefix=prefix),
"get_" + name + "_3d_point_size":
functools.partial(get_prefixed_3d_point_size, prefix=prefix),
"has_" + name + "_3d_point":
functools.partial(has_prefixed_3d_point, prefix=prefix),
"clear_" + name + "_3d_point":
functools.partial(clear_prefixed_3d_point, prefix=prefix),
}, module_dict=globals())
PREDICTED_PREFIX = "PREDICTED" PREDICTED_PREFIX = "PREDICTED"
View File
@@ -436,6 +436,21 @@ TEST(MediaSequenceTest, RoundTripBBoxPointPrefixed) {
}
}
TEST(MediaSequenceTest, RoundTripBBox3dPoint) {
tensorflow::SequenceExample sequence;
std::vector<std::vector<std::tuple<float, float, float>>> points = {
{std::make_tuple(0.3, 0.5, 0.1), std::make_tuple(0.4, 0.7, 0.2)},
{std::make_tuple(0.7, 0.5, 0.3), std::make_tuple(0.3, 0.4, 0.4)}};
for (int i = 0; i < points.size(); ++i) {
AddBBox3dPoint(points[i], &sequence);
ASSERT_EQ(GetBBox3dPointSize(sequence), i + 1);
const auto& sequence_points = GetBBox3dPointAt(sequence, i);
for (int j = 0; j < sequence_points.size(); ++j) {
EXPECT_EQ(sequence_points[j], points[i][j]);
}
}
}
TEST(MediaSequenceTest, RoundTripRegionParts) {
tensorflow::SequenceExample sequence;
std::vector<std::string> parts = {"HEAD", "FEET"};
View File
@@ -89,6 +89,9 @@ class MediaSequenceTest(tf.test.TestCase):
ms.add_bbox_xmax((0.47, 0.49), example)
ms.add_bbox_point_x((0.47, 0.49), example)
ms.add_bbox_point_y((0.47, 0.49), example)
ms.add_bbox_3d_point_x((0.47, 0.49), example)
ms.add_bbox_3d_point_y((0.47, 0.49), example)
ms.add_bbox_3d_point_z((0.47, 0.49), example)
ms.add_predicted_bbox_ymin((0.47, 0.49), example)
ms.add_predicted_bbox_xmin((0.47, 0.49), example)
ms.add_predicted_bbox_ymax((0.47, 0.49), example)
@@ -133,6 +136,30 @@ class MediaSequenceTest(tf.test.TestCase):
ms.clear_bbox(example)
self.assertEqual(0, ms.get_bbox_size(example))
def test_point_round_trip(self):
example = tf.train.SequenceExample()
points = np.array([[0.1, 0.2],
[0.5, 0.6]])
ms.add_bbox_point(points, example)
ms.add_bbox_point(points, example)
self.assertEqual(2, ms.get_bbox_point_size(example))
self.assertAllClose(points, ms.get_bbox_point_at(0, example))
self.assertTrue(ms.has_bbox_point(example))
ms.clear_bbox_point(example)
self.assertEqual(0, ms.get_bbox_point_size(example))
def test_3d_point_round_trip(self):
example = tf.train.SequenceExample()
points = np.array([[0.1, 0.2, 0.3],
[0.5, 0.6, 0.7]])
ms.add_bbox_3d_point(points, example)
ms.add_bbox_3d_point(points, example)
self.assertEqual(2, ms.get_bbox_3d_point_size(example))
self.assertAllClose(points, ms.get_bbox_3d_point_at(0, example))
self.assertTrue(ms.has_bbox_3d_point(example))
ms.clear_bbox_3d_point(example)
self.assertEqual(0, ms.get_bbox_3d_point_size(example))
def test_predicted_bbox_round_trip(self):
example = tf.train.SequenceExample()
boxes = np.array([[0.1, 0.2, 0.3, 0.4],
View File
@@ -19,6 +19,14 @@ package(default_visibility = [
"//mediapipe:__subpackages__",
])
cc_library(
name = "config",
hdrs = ["config.h"],
deps = [
"//mediapipe/framework:calculator_framework",
],
)
cc_library(
name = "cpu_op_resolver",
srcs = ["cpu_op_resolver.cc"],
@@ -69,6 +77,7 @@ cc_test(
srcs = ["tensor_buffer_test.cc"],
deps = [
":tensor_buffer",
":config",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
@ -99,6 +108,7 @@ cc_library(
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/gpu:api", "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
], ],
"//mediapipe:android": [ "//mediapipe:android": [
@ -108,7 +118,9 @@ cc_library(
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/gpu:api", "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
], ],
}) + ["@org_tensorflow//tensorflow/lite/core/api"], }) + ["@org_tensorflow//tensorflow/lite/core/api"],
View File
@@ -0,0 +1,59 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#define MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
#include "mediapipe/framework/calculator_framework.h"
// MediaPipe code should use the following defines to determine whether TFLite
// GPU support is available, and whether GL or Metal inference is available.
#ifdef MEDIAPIPE_DISABLE_GL_COMPUTE
#define MEDIAPIPE_TFLITE_GL_INFERENCE 0
#else
#define MEDIAPIPE_TFLITE_GL_INFERENCE 1
#endif // MEDIAPIPE_DISABLE_GL_COMPUTE
#ifdef MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 1
#else
#define MEDIAPIPE_TFLITE_METAL_INFERENCE 0
#endif // MEDIAPIPE_IOS
#define MEDIAPIPE_TFLITE_GPU_SUPPORTED \
((MEDIAPIPE_TFLITE_GL_INFERENCE) || (MEDIAPIPE_TFLITE_METAL_INFERENCE))
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
#import <Metal/Metal.h>
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
namespace mediapipe {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
typedef id<MTLBuffer> GpuTensor;
#else
struct DummyGpuTensor {};
typedef DummyGpuTensor GpuTensor; // Dummy define for less #ifdefs
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
} // namespace mediapipe
#endif // MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
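To show the intent of this new header, a small sketch follows (not part of the diff). The ExampleTensorSlot struct is a hypothetical consumer; it illustrates how code can gate GPU members on MEDIAPIPE_TFLITE_GPU_SUPPORTED and stay agnostic to the GL-versus-Metal choice through the GpuTensor alias.

#include <vector>

#include "mediapipe/util/tflite/config.h"

namespace mediapipe {

// Hypothetical per-tensor storage showing the intended use of the new macros:
// the GPU member exists only when a TFLite GPU backend is compiled in, and its
// concrete type (GlBuffer on GL platforms, id<MTLBuffer> on iOS) is hidden
// behind the GpuTensor alias.
struct ExampleTensorSlot {
  static constexpr bool kGpuInferenceCompiledIn = MEDIAPIPE_TFLITE_GPU_SUPPORTED;

  std::vector<float> cpu_data;
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
  GpuTensor gpu_data;
#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
};

}  // namespace mediapipe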
View File
@@ -130,8 +130,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto padding = params->padding;
auto compute_out_size = [padding](int image_size, int filter_size,
int stride) -> int {
return padding == kTfLitePaddingSame ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
? (image_size - filter_size + stride) / stride
: 0;
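As a quick sanity check of the padding arithmetic above (not part of the diff), here is a standalone sketch with worked numbers; the function name is illustrative.

// Worked example of the output-size formula used above.
// SAME padding:  ceil(image_size / stride)            == (image_size + stride - 1) / stride
// VALID padding: floor((image_size - filter_size) / stride) + 1
//                                                      == (image_size - filter_size + stride) / stride
// For image_size = 5, filter_size = 3, stride = 2:
//   SAME  -> (5 + 2 - 1) / 2 = 3
//   VALID -> (5 - 3 + 2) / 2 = 2
int ComputeOutSizeExample(bool same_padding, int image_size, int filter_size,
                          int stride) {
  return same_padding ? (image_size + stride - 1) / stride
                      : (image_size - filter_size + stride) / stride;
}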
View File
@@ -2,6 +2,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/tflite/config.h"
namespace mediapipe {
@@ -12,7 +13,7 @@ TEST(Cpu, BasicTest) {
EXPECT_FALSE(tb.UsesGpu());
}
#if MEDIAPIPE_TFLITE_GL_INFERENCE
TEST(Gpu, BasicTest) {
TensorBuffer tb;
std::shared_ptr<tflite::gpu::gl::GlBuffer> tfg_tb =
@@ -20,7 +21,7 @@ TEST(Gpu, BasicTest) {
tb = TensorBuffer(tfg_tb);
EXPECT_TRUE(tb.UsesGpu());
}
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
}  // namespace mediapipe
View File
@@ -30,6 +30,13 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"
// This code should be enabled as soon as the TensorFlow version used by
// MediaPipe includes this module.
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
namespace tflite {
namespace gpu {
namespace {
@@ -51,6 +58,19 @@ ObjectDef GetSSBOObjectDef(int channels) {
mediapipe::Status TFLiteGPURunner::InitializeWithModel(
const tflite::FlatBufferModel& flatbuffer,
const tflite::OpResolver& op_resolver) {
// GraphFloat32 is created twice because the OpenCL and OpenGL backends apply
// different backend-specific graph transformations in place during
// initialization. As GraphFloat32 is not copyable by design, we keep two
// copies of the graph until inference is built. This does not affect run-time
// memory usage, because both graph_gl_ and graph_cl_ are deleted at the end
// of the initialization stage.
graph_gl_ = std::make_unique<GraphFloat32>();
graph_cl_ = std::make_unique<GraphFloat32>();
MP_RETURN_IF_ERROR(
BuildFromFlatBuffer(flatbuffer, op_resolver, graph_gl_.get()));
MP_RETURN_IF_ERROR(
BuildFromFlatBuffer(flatbuffer, op_resolver, graph_cl_.get()));
for (const auto& input : graph_gl_->inputs()) {
input_shapes_.push_back(input->tensor.shape);
}
@@ -140,6 +160,19 @@ mediapipe::Status TFLiteGPURunner::InitializeOpenGL(
absl::Status TFLiteGPURunner::InitializeOpenCL(
std::unique_ptr<InferenceBuilder>* builder) {
#ifdef __ANDROID__
cl::InferenceEnvironmentOptions env_options;
cl::InferenceEnvironmentProperties properties;
cl::InferenceOptions cl_options;
cl_options.priority1 = options_.priority1;
cl_options.priority2 = options_.priority2;
cl_options.priority3 = options_.priority3;
cl_options.usage = options_.usage;
MP_RETURN_IF_ERROR(
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
cl_options, std::move(*graph_cl_), builder));
#endif
return absl::OkStatus();
}
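A hedged sketch of how a caller might drive TFLiteGPURunner end to end (not part of the diff). Only InitializeWithModel, Build, and the shape getters appear in this commit; the construction from tflite::gpu::InferenceOptions, the include paths, and the status macros are assumptions.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Hypothetical helper: load a .tflite model and prepare a GPU runner for it.
// The TFLiteGPURunner(InferenceOptions) constructor is assumed; only
// InitializeWithModel(), Build(), and GetInputShapes() are shown in the diff.
mediapipe::Status BuildRunnerFromFile(
    const std::string& model_path,
    std::unique_ptr<tflite::gpu::TFLiteGPURunner>* out) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
  if (!model) return absl::InternalError("Failed to load " + model_path);
  tflite::ops::builtin::BuiltinOpResolver op_resolver;

  tflite::gpu::InferenceOptions options;  // default usage and priorities
  auto runner = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
  MP_RETURN_IF_ERROR(runner->InitializeWithModel(*model, op_resolver));
  MP_RETURN_IF_ERROR(runner->Build());  // selects the GL or CL backend
  *out = std::move(runner);
  return absl::OkStatus();
}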
View File
@@ -27,6 +27,10 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h"
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
namespace tflite {
namespace gpu {
@@ -64,6 +68,9 @@ class TFLiteGPURunner {
mediapipe::Status Build();
mediapipe::Status Invoke();
std::vector<BHWC> GetInputShapes() { return input_shapes_; }
std::vector<BHWC> GetOutputShapes() { return output_shapes_; }
private:
mediapipe::Status InitializeOpenGL(
std::unique_ptr<InferenceBuilder>* builder);
@@ -73,6 +80,10 @@ class TFLiteGPURunner {
InferenceOptions options_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#ifdef __ANDROID__
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#endif
// graph_gl_ and graph_cl_ are maintained temporarily and become invalid after runner_ is ready.
std::unique_ptr<GraphFloat32> graph_gl_;
std::unique_ptr<GraphFloat32> graph_cl_;
View File
@@ -0,0 +1,50 @@
diff --git a/googletest/include/gtest/internal/gtest-internal.h b/googletest/include/gtest/internal/gtest-internal.h
index 7f1a5b00e..c36029ee1 100644
--- a/googletest/include/gtest/internal/gtest-internal.h
+++ b/googletest/include/gtest/internal/gtest-internal.h
@@ -94,6 +94,12 @@ namespace proto2 {
class MessageLite;
}
+namespace google {
+namespace protobuf {
+class MessageLite;
+}
+}
+
namespace testing {
// Forward declarations.
@@ -881,10 +887,15 @@ class GTEST_API_ Random {
typename std::remove_const<typename std::remove_reference<T>::type>::type
// IsAProtocolMessage<T>::value is a compile-time bool constant that's
-// true if and only if T is type proto2::MessageLite or a subclass of it.
+// true if and only if T is type proto2::MessageLite or
+// google::protobuf::MessageLite or a subclass of one of them.
template <typename T>
struct IsAProtocolMessage
- : public std::is_convertible<const T*, const ::proto2::MessageLite*> {};
+ : public std::integral_constant<
+ bool,
+ std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
+ std::is_convertible<
+ const T*, const ::google::protobuf::MessageLite*>::value> {};
// When the compiler sees expression IsContainerTest<C>(0), if C is an
// STL-style container class, the first overload of IsContainerTest
diff --git a/googletest/test/gtest_unittest.cc b/googletest/test/gtest_unittest.cc
index 005a2d40d..631180e3d 100644
--- a/googletest/test/gtest_unittest.cc
+++ b/googletest/test/gtest_unittest.cc
@@ -7115,6 +7115,10 @@ TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAProtocolMessage) {
EXPECT_TRUE(IsAProtocolMessage<::proto2::MessageLite>::value);
}
+TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAnOpenSourceProtocolMessage) {
+ EXPECT_TRUE(IsAProtocolMessage<::google::protobuf::MessageLite>::value);
+}
+
// Tests that IsAProtocolMessage<T>::value is false when T is neither
// ::proto2::Message nor a sub-class of it.
TEST(IsAProtocolMessageTest, ValueIsFalseWhenTypeIsNotAProtocolMessage) {