Project import generated by Copybara.

GitOrigin-RevId: 6e4aff1cc351be3ae4537b677f36d139ee50ce09
This commit is contained in:
MediaPipe Team 2021-03-25 15:01:44 -07:00 committed by chuoling
parent a92cff7a60
commit 7c331ad58b
175 changed files with 4804 additions and 1325 deletions

View File

@ -54,7 +54,7 @@ RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python RUN ln -s /usr/bin/python3 /usr/bin/python
# Install bazel # Install bazel
ARG BAZEL_VERSION=3.4.1 ARG BAZEL_VERSION=3.7.2
RUN mkdir /bazel && \ RUN mkdir /bazel && \
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \

View File

@ -10,3 +10,7 @@ include requirements.txt
recursive-include mediapipe/modules *.tflite *.txt *.binarypb recursive-include mediapipe/modules *.tflite *.txt *.binarypb
exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite
exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite
exclude mediapipe/modules/objectron/object_detection_3d_sneakers.tflite
exclude mediapipe/modules/objectron/object_detection_3d_chair.tflite
exclude mediapipe/modules/objectron/object_detection_3d_camera.tflite
exclude mediapipe/modules/objectron/object_detection_3d_cup.tflite

View File

@ -44,7 +44,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |

View File

@ -2,16 +2,19 @@ workspace(name = "mediapipe")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
skylib_version = "0.9.0"
http_archive( http_archive(
name = "bazel_skylib", name = "bazel_skylib",
type = "tar.gz", type = "tar.gz",
url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel_skylib-{}.tar.gz".format (skylib_version, skylib_version), urls = [
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz",
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz",
],
sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c",
) )
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
bazel_skylib_workspace()
load("@bazel_skylib//lib:versions.bzl", "versions") load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "3.4.0") versions.check(minimum_bazel_version = "3.7.2")
# ABSL cpp library lts_2020_09_23 # ABSL cpp library lts_2020_09_23
http_archive( http_archive(
@ -38,8 +41,8 @@ http_archive(
http_archive( http_archive(
name = "rules_foreign_cc", name = "rules_foreign_cc",
strip_prefix = "rules_foreign_cc-main", strip_prefix = "rules_foreign_cc-0.1.0",
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/main.zip", url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
) )
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
@ -304,8 +307,8 @@ http_archive(
# Maven dependencies. # Maven dependencies.
RULES_JVM_EXTERNAL_TAG = "3.2" RULES_JVM_EXTERNAL_TAG = "4.0"
RULES_JVM_EXTERNAL_SHA = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af" RULES_JVM_EXTERNAL_SHA = "31701ad93dbfe544d597dbe62c9a1fdd76d81d8a9150c2bf1ecf928ecdf97169"
http_archive( http_archive(
name = "rules_jvm_external", name = "rules_jvm_external",
@ -318,7 +321,6 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
# Important: there can only be one maven_install rule. Add new maven deps here. # Important: there can only be one maven_install rule. Add new maven deps here.
maven_install( maven_install(
name = "maven",
artifacts = [ artifacts = [
"androidx.concurrent:concurrent-futures:1.0.0-alpha03", "androidx.concurrent:concurrent-futures:1.0.0-alpha03",
"androidx.lifecycle:lifecycle-common:2.2.0", "androidx.lifecycle:lifecycle-common:2.2.0",
@ -343,10 +345,10 @@ maven_install(
"org.hamcrest:hamcrest-library:1.3", "org.hamcrest:hamcrest-library:1.3",
], ],
repositories = [ repositories = [
"https://jcenter.bintray.com",
"https://maven.google.com", "https://maven.google.com",
"https://dl.google.com/dl/android/maven2", "https://dl.google.com/dl/android/maven2",
"https://repo1.maven.org/maven2", "https://repo1.maven.org/maven2",
"https://jcenter.bintray.com",
], ],
fetch_sources = True, fetch_sources = True,
version_conflict_policy = "pinned", version_conflict_policy = "pinned",
@ -363,10 +365,10 @@ http_archive(
], ],
) )
#Tensorflow repo should always go after the other external dependencies. # Tensorflow repo should always go after the other external dependencies.
# 2020-12-09 # 2021-03-25
_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a" _TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887"
_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0" _TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f"
http_archive( http_archive(
name = "org_tensorflow", name = "org_tensorflow",
urls = [ urls = [
@ -383,5 +385,7 @@ http_archive(
sha256 = _TENSORFLOW_SHA256, sha256 = _TENSORFLOW_SHA256,
) )
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
tf_workspace(tf_repo_name = "org_tensorflow") tf_workspace3()
load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
tf_workspace2()

View File

@ -12,19 +12,30 @@ nav_order: 3
{:toc} {:toc}
--- ---
Each calculator is a node of of a graph. We describe how to create a new calculator, how to initialize a calculator, how to perform its calculations, input and output streams, timestamps, and options Calculators communicate by sending and receiving packets. Typically a single
packet is sent along each input stream at each input timestamp. A packet can
contain any kind of data, such as a single frame of video or a single integer
detection count.
## Creating a packet ## Creating a packet
Packets are generally created with `MediaPipe::Adopt()` (from packet.h). Packets are generally created with `mediapipe::MakePacket<T>()` or
`mediapipe::Adopt()` (from packet.h).
```c++ ```c++
// Create some data. // Create a packet containing some new data.
auto data = absl::make_unique<MyDataClass>("constructor_argument"); Packet p = MakePacket<MyDataClass>("constructor_argument");
// Create a packet to own the data.
Packet p = Adopt(data.release());
// Make a new packet with the same data and a different timestamp. // Make a new packet with the same data and a different timestamp.
Packet p2 = p.At(Timestamp::PostStream()); Packet p2 = p.At(Timestamp::PostStream());
``` ```
or:
```c++
// Create some new data.
auto data = absl::make_unique<MyDataClass>("constructor_argument");
// Create a packet to own the data.
Packet p = Adopt(data.release()).At(Timestamp::PostStream());
```
Data within a packet is accessed with `Packet::Get<T>()` Data within a packet is accessed with `Packet::Get<T>()`

View File

@ -28,7 +28,7 @@ Gradle.
* Install MediaPipe following these [instructions](./install.md). * Install MediaPipe following these [instructions](./install.md).
* Setup Java Runtime. * Setup Java Runtime.
* Setup Android SDK release 28.0.3 and above. * Setup Android SDK release 28.0.3 and above.
* Setup Android NDK r18b and above. * Setup Android NDK version between 18 and 21.
MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see
below for Android Studio setup). However, if you prefer using MediaPipe without below for Android Studio setup). However, if you prefer using MediaPipe without

View File

@ -25,25 +25,11 @@ install --user six`.
## Installing on Debian and Ubuntu ## Installing on Debian and Ubuntu
1. Install Bazel. 1. Install Bazelisk.
Follow the official Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
to install Bazel 3.4 or higher. to install Bazelisk.
For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs
to be built from source:
```bash
# For Bazel 3.4.1
mkdir $HOME/bazel-3.4.1
cd $HOME/bazel-3.4.1
wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
unzip bazel-3.4.1-dist.zip
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
sudo cp output/bazel /usr/local/bin/
```
2. Checkout MediaPipe repository. 2. Checkout MediaPipe repository.
@ -207,11 +193,11 @@ build issues.
**Disclaimer**: Running MediaPipe on CentOS is experimental. **Disclaimer**: Running MediaPipe on CentOS is experimental.
1. Install Bazel. 1. Install Bazelisk.
Follow the official Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html) [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
to install Bazel 3.4 or higher. to install Bazelisk.
2. Checkout MediaPipe repository. 2. Checkout MediaPipe repository.
@ -336,11 +322,11 @@ build issues.
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line * Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
Tools by `xcode-select --install`. Tools by `xcode-select --install`.
2. Install Bazel. 2. Install Bazelisk.
Follow the official Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
to install Bazel 3.4 or higher. to install Bazelisk.
3. Checkout MediaPipe repository. 3. Checkout MediaPipe repository.
@ -353,7 +339,7 @@ build issues.
4. Install OpenCV and FFmpeg. 4. Install OpenCV and FFmpeg.
Option 1. Use HomeBrew package manager tool to install the pre-compiled Option 1. Use HomeBrew package manager tool to install the pre-compiled
OpenCV 3.4.5 libraries. FFmpeg will be installed via OpenCV. OpenCV 3 libraries. FFmpeg will be installed via OpenCV.
```bash ```bash
$ brew install opencv@3 $ brew install opencv@3
@ -484,29 +470,36 @@ next section.
4. Install Visual C++ Build Tools 2019 and WinSDK 4. Install Visual C++ Build Tools 2019 and WinSDK
Go to https://visualstudio.microsoft.com/visual-cpp-build-tools, download Go to
build tools, and install Microsoft Visual C++ 2019 Redistributable and [the VisualStudio website](ttps://visualstudio.microsoft.com/visual-cpp-build-tools),
Microsoft Build Tools 2019. download build tools, and install Microsoft Visual C++ 2019 Redistributable
and Microsoft Build Tools 2019.
Download the WinSDK from Download the WinSDK from
https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/ and [the official MicroSoft website](https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/)
install. and install.
5. Install Bazel and add the location of the Bazel executable to the `%PATH%` 5. Install Bazel or Bazelisk and add the location of the Bazel executable to
environment variable. the `%PATH%` environment variable.
Follow the official Option 1. Follow
[Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html)
to install Bazel 3.4 or higher. to install Bazel 3.7.2 or higher.
6. Set Bazel variables. Option 2. Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
to install Bazelisk.
6. Set Bazel variables. Learn more details about
["Build on Windows"](https://docs.bazel.build/versions/master/windows.html#build-c-with-msvc)
in the Bazel official documentation.
``` ```
# Find the exact paths and version numbers from your local version. # Please find the exact paths and version numbers from your local version.
C:\> set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools C:\> set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools
C:\> set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC C:\> set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC
C:\> set BAZEL_VC_FULL_VERSION=14.25.28610 C:\> set BAZEL_VC_FULL_VERSION=<Your local VC version>
C:\> set BAZEL_WINSDK_FULL_VERSION=10.1.18362.1 C:\> set BAZEL_WINSDK_FULL_VERSION=<Your local WinSDK version>
``` ```
7. Checkout MediaPipe repository. 7. Checkout MediaPipe repository.
@ -593,19 +586,11 @@ cameras. Alternatively, you use a video file as input.
username@DESKTOP-TMVLBJ1:~$ sudo apt-get update && sudo apt-get install -y build-essential git python zip adb openjdk-8-jdk username@DESKTOP-TMVLBJ1:~$ sudo apt-get update && sudo apt-get install -y build-essential git python zip adb openjdk-8-jdk
``` ```
5. Install Bazel. 5. Install Bazelisk.
```bash Follow the official
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \ [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \ to install Bazelisk.
sudo mkdir -p /usr/local/bazel/3.4.1 && \
chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \
sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \
source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \
alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel'
```
6. Checkout MediaPipe repository. 6. Checkout MediaPipe repository.

View File

@ -44,7 +44,7 @@ Hair Segmentation
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |

View File

@ -183,8 +183,8 @@ function onResults(results) {
canvasCtx.restore(); canvasCtx.restore();
} }
const faceDetection = new Objectron({locateFile: (file) => { const faceDetection = new FaceDetection({locateFile: (file) => {
return `https://cdn.jsdelivr.net/npm/@mediapipe/objectron@0.0/${file}`; return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`;
}}); }});
faceDetection.setOptions({ faceDetection.setOptions({
minDetectionConfidence: 0.5 minDetectionConfidence: 0.5

View File

@ -358,15 +358,17 @@ cap.release()
## Example Apps ## Example Apps
Please first see general instructions for Please first see general instructions for
[Android](../getting_started/android.md) and [iOS](../getting_started/ios.md) on [Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and
how to build MediaPipe examples. [desktop](../getting_started/cpp.md) on how to build MediaPipe examples.
Note: To visualize a graph, copy the graph and paste it into Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how [MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../tools/visualizer.md). [visualizer documentation](../tools/visualizer.md).
### Two-stage Objectron ### Mobile
#### Two-stage Objectron
* Graph: * Graph:
[`mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt)
@ -404,7 +406,7 @@ to visualize its associated subgraphs, please see
* iOS target: Not available * iOS target: Not available
### Single-stage Objectron #### Single-stage Objectron
* Graph: * Graph:
[`mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt) [`mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt)
@ -428,7 +430,7 @@ to visualize its associated subgraphs, please see
* iOS target: Not available * iOS target: Not available
### Assets #### Assets
Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file
format into a custom .uuu format. This can be done for user assets as follows: format into a custom .uuu format. This can be done for user assets as follows:
@ -449,9 +451,35 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http
> single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as > single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as
> absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details. > absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details.
### Coordinate Systems
#### Object Coordinate ### Desktop
To build the application, run:
```bash
bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_detection_3d:objectron_cpu
```
To run the application, replace `<input video path>` and `<output video path>`
in the command below with your own paths, and `<landmark model path>` and
`<allowed labels>` with the following:
Category | `<landmark model path>` | `<allowed labels>`
:------- | :-------------------------------------------------------------------------- | :-----------------
Shoe | mediapipe/modules/objectron/object_detection_3d_sneakers.tflite | Footwear
Chair | mediapipe/modules/objectron/object_detection_3d_chair.tflite | Chair
Cup | mediapipe/modules/objectron/object_detection_3d_cup.tflite | Mug
Camera | mediapipe/modules/objectron/object_detection_3d_camera.tflite | Camera
```
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \
--calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>,box_landmark_model_path=<landmark model path>,allowed_labels=<allowed labels>
```
## Coordinate Systems
### Object Coordinate
Each object has its object coordinate frame. We use the below object coordinate Each object has its object coordinate frame. We use the below object coordinate
definition, with `+x` pointing right, `+y` pointing up and `+z` pointing front, definition, with `+x` pointing right, `+y` pointing up and `+z` pointing front,
@ -459,7 +487,7 @@ origin is at the center of the 3D bounding box.
![box_coordinate.svg](../images/box_coordinate.svg) ![box_coordinate.svg](../images/box_coordinate.svg)
#### Camera Coordinate ### Camera Coordinate
A 3D object is parameterized by its `scale` and `rotation`, `translation` with A 3D object is parameterized by its `scale` and `rotation`, `translation` with
regard to the camera coordinate frame. In this API we use the below camera regard to the camera coordinate frame. In this API we use the below camera
@ -476,7 +504,7 @@ camera frame by applying `rotation` and `translation`:
landmarks_3d = rotation * scale * unit_box + translation landmarks_3d = rotation * scale * unit_box + translation
``` ```
#### NDC Space ### NDC Space
In this API we use In this API we use
[NDC(normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html) [NDC(normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html)
@ -495,7 +523,7 @@ y_ndc = -fy * Y / Z + py
z_ndc = 1 / Z z_ndc = 1 / Z
``` ```
#### Pixel Space ### Pixel Space
In this API we set upper-left coner of an image as the origin of pixel In this API we set upper-left coner of an image as the origin of pixel
coordinate. One can convert from NDC to pixel space as follows: coordinate. One can convert from NDC to pixel space as follows:
@ -532,10 +560,11 @@ py = -py_pixel * 2.0 / image_height + 1.0
[Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html) [Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html)
* Google AI Blog: * Google AI Blog:
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html) [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak * Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
Shape Supervision](https://arxiv.org/abs/2003.03522) Shape Supervision](https://arxiv.org/abs/2003.03522)
* Paper: * Paper:
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)) ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020
* [Models and model cards](./models.md#objectron) * [Models and model cards](./models.md#objectron)
* [Python Colab](https://mediapipe.page.link/objectron_py_colab) * [Python Colab](https://mediapipe.page.link/objectron_py_colab)

View File

@ -25,10 +25,11 @@ One of the applications
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
can enable is fitness. More specifically - pose classification and repetition can enable is fitness. More specifically - pose classification and repetition
counting. In this section we'll provide basic guidance on building a custom pose counting. In this section we'll provide basic guidance on building a custom pose
classifier with the help of [Colabs](#colabs) and wrap it in a simple classifier with the help of [Colabs](#colabs) and wrap it in a simple fitness
[fitness app](https://mediapipe.page.link/mlkit-pose-classification-demo-app) demo within
powered by [ML Kit](https://developers.google.com/ml-kit). Push-ups and squats [ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app).
are used for demonstration purposes as the most common exercises. Push-ups and squats are used for demonstration purposes as the most common
exercises.
![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) | ![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) |
:--------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: |
@ -47,7 +48,7 @@ determines the object's class based on the closest samples in the training set.
classifier and form a training set using these [Colabs](#colabs), classifier and form a training set using these [Colabs](#colabs),
3. Perform the classification itself followed by repetition counting (e.g., in 3. Perform the classification itself followed by repetition counting (e.g., in
the the
[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)). [ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app)).
## Training Set ## Training Set
@ -76,7 +77,7 @@ video right in the Colab.
Code of the classifier is available both in the Code of the classifier is available both in the
[`Pose Classification Colab (Extended)`] and in the [`Pose Classification Colab (Extended)`] and in the
[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app). [ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app).
Please refer to them for details of the approach described below. Please refer to them for details of the approach described below.
The k-NN algorithm used for pose classification requires a feature vector The k-NN algorithm used for pose classification requires a feature vector
@ -127,11 +128,13 @@ where the pose class and the counter can't be changed.
## Future Work ## Future Work
We are actively working on improving BlazePose GHUM 3D's Z prediction. It will We are actively working on improving
allow us to use joint angles in the feature vectors, which are more natural and [BlazePose GHUM 3D](./pose.md#pose-landmark-model-blazepose-ghum-3d)'s Z
easier to configure (although distances can still be useful to detect touches prediction. It will allow us to use joint angles in the feature vectors, which
between body parts) and to perform rotation normalization of poses and reduce are more natural and easier to configure (although distances can still be useful
the number of camera angles required for accurate k-NN classification. to detect touches between body parts) and to perform rotation normalization of
poses and reduce the number of camera angles required for accurate k-NN
classification.
## Colabs ## Colabs

View File

@ -28,7 +28,7 @@ has_toc: false
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | |
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |

View File

@ -41,6 +41,7 @@ profiler_config {
trace_enabled: true trace_enabled: true
enable_profiler: true enable_profiler: true
trace_log_interval_count: 200 trace_log_interval_count: 200
trace_log_path: "/sdcard/Download/"
} }
``` ```
@ -64,7 +65,7 @@ MediaPipe will emit data into a pre-specified directory:
* On the desktop, this will be the `/tmp` directory. * On the desktop, this will be the `/tmp` directory.
* On Android, this will be the `/sdcard` directory. * On Android, this will be the external storage directory (e.g., `/storage/emulated/0/`).
* On iOS, this can be reached through XCode. Select "Window/Devices and * On iOS, this can be reached through XCode. Select "Window/Devices and
Simulators" and select the "Devices" tab. Simulators" and select the "Devices" tab.
@ -103,7 +104,7 @@ we record ten intervals of half a second each. This can be overridden by adding
* Include the line below in your `AndroidManifest.xml` file. * Include the line below in your `AndroidManifest.xml` file.
```xml ```xml
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" /> <uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE" />
``` ```
* Grant the permission either upon first app launch, or by going into * Grant the permission either upon first app launch, or by going into
@ -130,8 +131,8 @@ we record ten intervals of half a second each. This can be overridden by adding
events to a trace log files at: events to a trace log files at:
```bash ```bash
/sdcard/mediapipe_trace_0.binarypb /storage/emulated/0/Download/mediapipe_trace_0.binarypb
/sdcard/mediapipe_trace_1.binarypb /storage/emulated/0/Download/mediapipe_trace_1.binarypb
``` ```
After every 5 sec, writing shifts to a successive trace log file, such that After every 5 sec, writing shifts to a successive trace log file, such that
@ -139,10 +140,10 @@ we record ten intervals of half a second each. This can be overridden by adding
trace files have been written to the device using adb shell. trace files have been written to the device using adb shell.
```bash ```bash
adb shell "ls -la /sdcard/" adb shell "ls -la /storage/emulated/0/Download"
``` ```
On android, MediaPipe selects the external storage directory `/sdcard` for On android, MediaPipe selects the external storage (e.g., `/storage/emulated/0/`) for
trace logs. This directory can be overridden using the setting trace logs. This directory can be overridden using the setting
`trace_log_path`, like: `trace_log_path`, like:
@ -150,7 +151,7 @@ we record ten intervals of half a second each. This can be overridden by adding
profiler_config { profiler_config {
trace_enabled: true trace_enabled: true
enable_profiler: true enable_profiler: true
trace_log_path: "/sdcard/profiles/" trace_log_path: "/sdcard/Download/profiles/"
} }
``` ```
@ -161,7 +162,7 @@ we record ten intervals of half a second each. This can be overridden by adding
```bash ```bash
# from your terminal # from your terminal
adb pull /sdcard/mediapipe_trace_0.binarypb adb pull /storage/emulated/0/Download/mediapipe_trace_0.binarypb
# if successful you should see something like # if successful you should see something like
# /sdcard/mediapipe_trace_0.binarypb: 1 file pulled. 0.1 MB/s (6766 bytes in 0.045s) # /sdcard/mediapipe_trace_0.binarypb: 1 file pulled. 0.1 MB/s (6766 bytes in 0.045s)
``` ```

View File

@ -128,7 +128,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -147,7 +147,7 @@ cc_library(
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp/mfcc", "@com_google_audio_tools//audio/dsp/mfcc",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -168,7 +168,7 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:resampler", "@com_google_audio_tools//audio/dsp:resampler",
"@com_google_audio_tools//audio/dsp:resampler_q", "@com_google_audio_tools//audio/dsp:resampler_q",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -208,7 +208,7 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp:window_functions",
"@com_google_audio_tools//audio/dsp/spectrogram", "@com_google_audio_tools//audio/dsp/spectrogram",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -228,7 +228,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -242,9 +242,9 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
], ],
) )
@ -261,7 +261,7 @@ cc_test(
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -276,7 +276,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -296,7 +296,7 @@ cc_test(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:number_util", "@com_google_audio_tools//audio/dsp:number_util",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -314,7 +314,7 @@ cc_test(
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -333,7 +333,7 @@ cc_test(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -352,6 +352,6 @@ cc_test(
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:signal_vector_util", "@com_google_audio_tools//audio/dsp:signal_vector_util",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )

View File

@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -414,7 +414,7 @@ cc_library(
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -430,7 +430,7 @@ cc_library(
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -450,6 +450,20 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "nonzero_calculator",
srcs = ["nonzero_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "mux_calculator_test", name = "mux_calculator_test",
srcs = ["mux_calculator_test.cc"], srcs = ["mux_calculator_test.cc"],
@ -776,7 +790,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -793,7 +807,7 @@ cc_test(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -1024,7 +1038,7 @@ cc_library(
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -57,7 +57,7 @@ namespace mediapipe {
// //
// The "ALLOW" stream indicates the transition between accepting frames and // The "ALLOW" stream indicates the transition between accepting frames and
// dropping frames. "ALLOW = true" indicates the start of accepting frames // dropping frames. "ALLOW = true" indicates the start of accepting frames
// including the current timestamp, and "ALLOW = true" indicates the start of // including the current timestamp, and "ALLOW = false" indicates the start of
// dropping frames including the current timestamp. // dropping frames including the current timestamp.
// //
// FlowLimiterCalculator provides limited support for multiple input streams. // FlowLimiterCalculator provides limited support for multiple input streams.

View File

@ -0,0 +1,42 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace api2 {
// A Calculator that returns 0 if INPUT is 0, and 1 otherwise.
class NonZeroCalculator : public Node {
public:
static constexpr Input<int>::SideFallback kIn{"INPUT"};
static constexpr Output<int> kOut{"OUTPUT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final {
if (!kIn(cc).IsEmpty()) {
auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0);
kOut(cc).Send(std::move(output));
}
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(NonZeroCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -87,7 +87,6 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
flush_last_packet_ = resampler_options.flush_last_packet(); flush_last_packet_ = resampler_options.flush_last_packet();
jitter_ = resampler_options.jitter(); jitter_ = resampler_options.jitter();
jitter_with_reflection_ = resampler_options.jitter_with_reflection();
input_data_id_ = cc->Inputs().GetId("DATA", 0); input_data_id_ = cc->Inputs().GetId("DATA", 0);
if (!input_data_id_.IsValid()) { if (!input_data_id_.IsValid()) {
@ -98,11 +97,7 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
output_data_id_ = cc->Outputs().GetId("", 0); output_data_id_ = cc->Outputs().GetId("", 0);
} }
period_count_ = 0;
frame_rate_ = resampler_options.frame_rate(); frame_rate_ = resampler_options.frame_rate();
base_timestamp_ = resampler_options.has_base_timestamp()
? Timestamp(resampler_options.base_timestamp())
: Timestamp::Unset();
start_time_ = resampler_options.has_start_time() start_time_ = resampler_options.has_start_time()
? Timestamp(resampler_options.start_time()) ? Timestamp(resampler_options.start_time())
: Timestamp::Min(); : Timestamp::Min();
@ -141,30 +136,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
} }
} }
if (jitter_ != 0.0) { strategy_ = GetSamplingStrategy(resampler_options);
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) { return strategy_->Open(cc);
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
"the actual value.";
}
if (flush_last_packet_) {
flush_last_packet_ = false;
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
"SecureRandom is not available. With \"jitter\" specified, "
"PacketResamplerCalculator processing cannot proceed.");
}
packet_reservoir_random_ = CreateSecureRandom(seed);
}
packet_reservoir_ =
std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
return absl::OkStatus();
} }
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
@ -177,171 +151,13 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
} }
if (jitter_ != 0.0 && random_ != nullptr) {
// Packet reservior is used to make sure there's an output for every period, if (absl::Status status = strategy_->Process(cc); !status.ok()) {
// e.g. partial period at the end of the stream. return status; // Avoid MP_RETURN_IF_ERROR macro for external release.
if (packet_reservoir_->IsEnabled() &&
(first_timestamp_ == Timestamp::Unset() ||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
auto curr_packet = cc->Inputs().Get(input_data_id_).Value();
packet_reservoir_->AddSample(curr_packet);
}
MP_RETURN_IF_ERROR(ProcessWithJitter(cc));
} else {
MP_RETURN_IF_ERROR(ProcessWithoutJitter(cc));
} }
last_packet_ = cc->Inputs().Get(input_data_id_).Value(); last_packet_ = cc->Inputs().Get(input_data_id_).Value();
return absl::OkStatus();
}
void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() {
next_output_timestamp_min_ = first_timestamp_;
if (jitter_with_reflection_) {
next_output_timestamp_ =
first_timestamp_ + random_->UnbiasedUniform64(frame_time_usec_);
return;
}
next_output_timestamp_ =
first_timestamp_ + frame_time_usec_ * random_->RandFloat();
}
void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() {
packet_reservoir_->Clear();
if (jitter_with_reflection_) {
next_output_timestamp_min_ += frame_time_usec_;
Timestamp next_output_timestamp_max_ =
next_output_timestamp_min_ + frame_time_usec_;
next_output_timestamp_ += frame_time_usec_ +
random_->UnbiasedUniform64(2 * jitter_usec_ + 1) -
jitter_usec_;
next_output_timestamp_ = Timestamp(ReflectBetween(
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
next_output_timestamp_max_.Value()));
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
return;
}
packet_reservoir_->Disable();
next_output_timestamp_ +=
frame_time_usec_ *
((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
}
absl::Status PacketResamplerCalculator::ProcessWithJitter(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
RET_CHECK_NE(jitter_, 0.0);
if (first_timestamp_ == Timestamp::Unset()) {
first_timestamp_ = cc->InputTimestamp();
InitializeNextOutputTimestampWithJitter();
if (first_timestamp_ == next_output_timestamp_) {
OutputWithinLimits(
cc,
cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
}
return absl::OkStatus();
}
if (frame_time_usec_ <
(cc->InputTimestamp() - last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
while (true) {
const int64 last_diff =
(next_output_timestamp_ - last_packet_.Timestamp()).Value();
RET_CHECK_GT(last_diff, 0);
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
if (curr_diff > 0) {
break;
}
OutputWithinLimits(cc, (std::abs(curr_diff) > last_diff
? last_packet_
: cc->Inputs().Get(input_data_id_).Value())
.At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
// From now on every time a packet is emitted the timestamp of the next
// packet becomes known; that timestamp is stored in next_output_timestamp_.
// The only exception to this rule is the packet emitted from Close() which
// can only happen when jitter_with_reflection is enabled but in this case
// next_output_timestamp_min_ is a non-decreasing lower bound of any
// subsequent packet.
const Timestamp timestamp_bound = jitter_with_reflection_
? next_output_timestamp_min_
: next_output_timestamp_;
cc->Outputs().Get(output_data_id_).SetNextTimestampBound(timestamp_bound);
}
return absl::OkStatus();
}
absl::Status PacketResamplerCalculator::ProcessWithoutJitter(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
RET_CHECK_EQ(jitter_, 0.0);
if (first_timestamp_ == Timestamp::Unset()) {
// This is the first packet, initialize the first_timestamp_.
if (base_timestamp_ == Timestamp::Unset()) {
// Initialize first_timestamp_ with exactly the first packet timestamp.
first_timestamp_ = cc->InputTimestamp();
} else {
// Initialize first_timestamp_ with the first packet timestamp
// aligned to the base_timestamp_.
int64 first_index = MathUtil::SafeRound<int64, double>(
(cc->InputTimestamp() - base_timestamp_).Seconds() * frame_rate_);
first_timestamp_ =
base_timestamp_ + TimestampDiffFromSeconds(first_index / frame_rate_);
}
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs()
.Tag("VIDEO_HEADER")
.Add(new VideoHeader(video_header_), Timestamp::PreStream());
}
}
const Timestamp received_timestamp = cc->InputTimestamp();
const int64 received_timestamp_idx =
TimestampToPeriodIndex(received_timestamp);
// Only consider the received packet if it belongs to the current period
// (== period_count_) or to a newer one (> period_count_).
if (received_timestamp_idx >= period_count_) {
// Fill the empty periods until we are in the same index as the received
// packet.
while (received_timestamp_idx > period_count_) {
OutputWithinLimits(
cc, last_packet_.At(PeriodIndexToTimestamp(period_count_)));
++period_count_;
}
// Now, if the received packet has a timestamp larger than the middle of
// the current period, we can send a packet without waiting. We send the
// one closer to the middle.
Timestamp target_timestamp = PeriodIndexToTimestamp(period_count_);
if (received_timestamp >= target_timestamp) {
bool have_last_packet = (last_packet_.Timestamp() != Timestamp::Unset());
bool send_current =
!have_last_packet || (received_timestamp - target_timestamp <=
target_timestamp - last_packet_.Timestamp());
if (send_current) {
OutputWithinLimits(
cc, cc->Inputs().Get(input_data_id_).Value().At(target_timestamp));
} else {
OutputWithinLimits(cc, last_packet_.At(target_timestamp));
}
++period_count_;
}
// TODO: Add a mechanism to the framework to allow these packets
// to be output earlier (without waiting for a much later packet to
// arrive)
// Update the bound for the next packet.
cc->Outputs()
.Get(output_data_id_)
.SetNextTimestampBound(PeriodIndexToTimestamp(period_count_));
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -349,17 +165,34 @@ absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) {
if (!cc->GraphStatus().ok()) { if (!cc->GraphStatus().ok()) {
return absl::OkStatus(); return absl::OkStatus();
} }
// Emit the last packet received if we have at least one packet, but
// haven't sent anything for its period. return strategy_->Close(cc);
if (first_timestamp_ != Timestamp::Unset() && flush_last_packet_ && }
TimestampToPeriodIndex(last_packet_.Timestamp()) == period_count_) {
OutputWithinLimits(cc, std::unique_ptr<PacketResamplerStrategy>
last_packet_.At(PeriodIndexToTimestamp(period_count_))); PacketResamplerCalculator::GetSamplingStrategy(
const PacketResamplerCalculatorOptions& options) {
if (options.reproducible_sampling()) {
if (!options.jitter_with_reflection()) {
LOG(WARNING)
<< "reproducible_sampling enabled w/ jitter_with_reflection "
"disabled. "
<< "reproducible_sampling always uses jitter with reflection, "
<< "Ignoring jitter_with_reflection setting.";
} }
if (!packet_reservoir_->IsEmpty()) { return absl::make_unique<ReproducibleJitterWithReflectionStrategy>(this);
OutputWithinLimits(cc, packet_reservoir_->GetSample());
} }
return absl::OkStatus();
if (options.jitter() == 0) {
return absl::make_unique<NoJitterStrategy>(this);
}
if (options.jitter_with_reflection()) {
return absl::make_unique<LegacyJitterWithReflectionStrategy>(this);
}
// With jitter and no reflection.
return absl::make_unique<JitterWithoutReflectionStrategy>(this);
} }
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const { Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const {
@ -385,4 +218,479 @@ void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc,
} }
} }
absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
const auto resampler_options =
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
cc->InputSidePackets(), "OPTIONS");
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
"the actual value.";
}
if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
"SecureRandom is not available. With \"jitter\" specified, "
"PacketResamplerCalculator processing cannot proceed.");
}
packet_reservoir_random_ = CreateSecureRandom(seed);
packet_reservoir_ =
std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
return absl::OkStatus();
}
absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
if (!packet_reservoir_->IsEmpty()) {
LOG(INFO) << "Emitting pack from reservoir.";
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
}
return absl::OkStatus();
}
absl::Status LegacyJitterWithReflectionStrategy::Process(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
if (packet_reservoir_->IsEnabled() &&
(first_timestamp_ == Timestamp::Unset() ||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
packet_reservoir_->AddSample(curr_packet);
}
if (first_timestamp_ == Timestamp::Unset()) {
first_timestamp_ = cc->InputTimestamp();
InitializeNextOutputTimestampWithJitter();
if (first_timestamp_ == next_output_timestamp_) {
calculator_->OutputWithinLimits(cc, cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
}
return absl::OkStatus();
}
if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
while (true) {
const int64 last_diff =
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
.Value();
RET_CHECK_GT(last_diff, 0);
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
if (curr_diff > 0) {
break;
}
calculator_->OutputWithinLimits(
cc, (std::abs(curr_diff) > last_diff
? calculator_->last_packet_
: cc->Inputs().Get(calculator_->input_data_id_).Value())
.At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
// From now on every time a packet is emitted the timestamp of the next
// packet becomes known; that timestamp is stored in next_output_timestamp_.
// The only exception to this rule is the packet emitted from Close() which
// can only happen when jitter_with_reflection is enabled but in this case
// next_output_timestamp_min_ is a non-decreasing lower bound of any
// subsequent packet.
const Timestamp timestamp_bound = next_output_timestamp_min_;
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(timestamp_bound);
}
return absl::OkStatus();
}
void LegacyJitterWithReflectionStrategy::
InitializeNextOutputTimestampWithJitter() {
next_output_timestamp_min_ = first_timestamp_;
next_output_timestamp_ =
first_timestamp_ +
random_->UnbiasedUniform64(calculator_->frame_time_usec_);
}
void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
packet_reservoir_->Clear();
next_output_timestamp_min_ += calculator_->frame_time_usec_;
Timestamp next_output_timestamp_max_ =
next_output_timestamp_min_ + calculator_->frame_time_usec_;
next_output_timestamp_ +=
calculator_->frame_time_usec_ +
random_->UnbiasedUniform64(2 * calculator_->jitter_usec_ + 1) -
calculator_->jitter_usec_;
next_output_timestamp_ = Timestamp(ReflectBetween(
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
next_output_timestamp_max_.Value()));
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
}
absl::Status ReproducibleJitterWithReflectionStrategy::Open(
CalculatorContext* cc) {
const auto resampler_options =
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
cc->InputSidePackets(), "OPTIONS");
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
"the actual value.";
}
if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
"SecureRandom is not available. With \"jitter\" specified, "
"PacketResamplerCalculator processing cannot proceed.");
}
return absl::OkStatus();
}
absl::Status ReproducibleJitterWithReflectionStrategy::Close(
CalculatorContext* cc) {
// If last packet is non-empty and a packet hasn't been emitted for this
// period, emit the last packet.
if (!calculator_->last_packet_.IsEmpty() && !packet_emitted_this_period_) {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(next_output_timestamp_));
}
return absl::OkStatus();
}
absl::Status ReproducibleJitterWithReflectionStrategy::Process(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
Packet current_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
if (calculator_->last_packet_.IsEmpty()) {
// last_packet is empty, this is the first packet of the stream.
InitializeNextOutputTimestamp(current_packet.Timestamp());
// If next_output_timestamp_ happens to fall before current_packet, emit
// current packet. Only a single packet can be emitted at the beginning
// of the stream.
if (next_output_timestamp_ < current_packet.Timestamp()) {
calculator_->OutputWithinLimits(
cc, current_packet.At(next_output_timestamp_));
packet_emitted_this_period_ = true;
}
return absl::OkStatus();
}
// Last packet is set, so we are mid-stream.
if (calculator_->frame_time_usec_ <
(current_packet.Timestamp() - calculator_->last_packet_.Timestamp())
.Value()) {
// Note, if the stream is upsampling, this could lead to the same packet
// being emitted twice. Upsampling and jitter doesn't make much sense
// but does technically work.
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
// Since we may be upsampling, we need to iteratively advance the
// next_output_timestamp_ one period at a time until it reaches the period
// current_packet is in. During this process, last_packet and/or
// current_packet may be repeatly emitted.
UpdateNextOutputTimestamp(current_packet.Timestamp());
while (!packet_emitted_this_period_ &&
next_output_timestamp_ <= current_packet.Timestamp()) {
// last_packet < next_output_timestamp_ <= current_packet,
// so emit the closest packet.
Packet packet_to_emit =
current_packet.Timestamp() - next_output_timestamp_ <
next_output_timestamp_ - calculator_->last_packet_.Timestamp()
? current_packet
: calculator_->last_packet_;
calculator_->OutputWithinLimits(cc,
packet_to_emit.At(next_output_timestamp_));
packet_emitted_this_period_ = true;
// If we are upsampling, packet_emitted_this_period_ can be reset by
// the following UpdateNext and the loop will iterate.
UpdateNextOutputTimestamp(current_packet.Timestamp());
}
// Set the bounds on the output stream. Note, if we emitted a packet
// above, it will already be set at next_output_timestamp_ + 1, in which
// case we have to skip setting it.
if (cc->Outputs().Get(calculator_->output_data_id_).NextTimestampBound() <
next_output_timestamp_) {
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(next_output_timestamp_);
}
return absl::OkStatus();
}
void ReproducibleJitterWithReflectionStrategy::InitializeNextOutputTimestamp(
Timestamp current_timestamp) {
if (next_output_timestamp_min_ != Timestamp::Unset()) {
return;
}
next_output_timestamp_min_ = Timestamp(0);
next_output_timestamp_ =
Timestamp(GetNextRandom(calculator_->frame_time_usec_));
// While the current timestamp is ahead of the max (i.e. min + frame_time),
// fast-forward.
while (current_timestamp >=
next_output_timestamp_min_ + calculator_->frame_time_usec_) {
packet_emitted_this_period_ = true; // Force update...
UpdateNextOutputTimestamp(current_timestamp);
}
}
void ReproducibleJitterWithReflectionStrategy::UpdateNextOutputTimestamp(
Timestamp current_timestamp) {
if (packet_emitted_this_period_ &&
current_timestamp >=
next_output_timestamp_min_ + calculator_->frame_time_usec_) {
next_output_timestamp_min_ += calculator_->frame_time_usec_;
Timestamp next_output_timestamp_max_ =
next_output_timestamp_min_ + calculator_->frame_time_usec_;
next_output_timestamp_ += calculator_->frame_time_usec_ +
GetNextRandom(2 * calculator_->jitter_usec_ + 1) -
calculator_->jitter_usec_;
next_output_timestamp_ = Timestamp(ReflectBetween(
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
next_output_timestamp_max_.Value()));
packet_emitted_this_period_ = false;
}
}
absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
const auto resampler_options =
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
cc->InputSidePackets(), "OPTIONS");
if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
"the actual value.";
}
if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
"SecureRandom is not available. With \"jitter\" specified, "
"PacketResamplerCalculator processing cannot proceed.");
}
packet_reservoir_random_ = CreateSecureRandom(seed);
packet_reservoir_ =
absl::make_unique<PacketReservoir>(packet_reservoir_random_.get());
return absl::OkStatus();
}
absl::Status JitterWithoutReflectionStrategy::Close(CalculatorContext* cc) {
if (!packet_reservoir_->IsEmpty()) {
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
}
return absl::OkStatus();
}
absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
// Packet reservior is used to make sure there's an output for every period,
// e.g. partial period at the end of the stream.
if (packet_reservoir_->IsEnabled() &&
(calculator_->first_timestamp_ == Timestamp::Unset() ||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
packet_reservoir_->AddSample(curr_packet);
}
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
calculator_->first_timestamp_ = cc->InputTimestamp();
InitializeNextOutputTimestamp();
if (calculator_->first_timestamp_ == next_output_timestamp_) {
calculator_->OutputWithinLimits(cc, cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(next_output_timestamp_));
UpdateNextOutputTimestamp();
}
return absl::OkStatus();
}
if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
while (true) {
const int64 last_diff =
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
.Value();
RET_CHECK_GT(last_diff, 0);
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
if (curr_diff > 0) {
break;
}
calculator_->OutputWithinLimits(
cc, (std::abs(curr_diff) > last_diff
? calculator_->last_packet_
: cc->Inputs().Get(calculator_->input_data_id_).Value())
.At(next_output_timestamp_));
UpdateNextOutputTimestamp();
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(next_output_timestamp_);
}
return absl::OkStatus();
}
void JitterWithoutReflectionStrategy::InitializeNextOutputTimestamp() {
next_output_timestamp_min_ = calculator_->first_timestamp_;
next_output_timestamp_ = calculator_->first_timestamp_ +
calculator_->frame_time_usec_ * random_->RandFloat();
}
void JitterWithoutReflectionStrategy::UpdateNextOutputTimestamp() {
packet_reservoir_->Clear();
packet_reservoir_->Disable();
next_output_timestamp_ += calculator_->frame_time_usec_ *
((1.0 - calculator_->jitter_) +
2.0 * calculator_->jitter_ * random_->RandFloat());
}
absl::Status NoJitterStrategy::Open(CalculatorContext* cc) {
const auto resampler_options =
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
cc->InputSidePackets(), "OPTIONS");
base_timestamp_ = resampler_options.has_base_timestamp()
? Timestamp(resampler_options.base_timestamp())
: Timestamp::Unset();
period_count_ = 0;
return absl::OkStatus();
}
absl::Status NoJitterStrategy::Close(CalculatorContext* cc) {
// Emit the last packet received if we have at least one packet, but
// haven't sent anything for its period.
if (calculator_->first_timestamp_ != Timestamp::Unset() &&
calculator_->flush_last_packet_ &&
calculator_->TimestampToPeriodIndex(
calculator_->last_packet_.Timestamp()) == period_count_) {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(
calculator_->PeriodIndexToTimestamp(period_count_)));
}
return absl::OkStatus();
}
absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
// This is the first packet, initialize the first_timestamp_.
if (base_timestamp_ == Timestamp::Unset()) {
// Initialize first_timestamp_ with exactly the first packet timestamp.
calculator_->first_timestamp_ = cc->InputTimestamp();
} else {
// Initialize first_timestamp_ with the first packet timestamp
// aligned to the base_timestamp_.
int64 first_index = MathUtil::SafeRound<int64, double>(
(cc->InputTimestamp() - base_timestamp_).Seconds() *
calculator_->frame_rate_);
calculator_->first_timestamp_ =
base_timestamp_ +
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
}
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs()
.Tag("VIDEO_HEADER")
.Add(new VideoHeader(calculator_->video_header_),
Timestamp::PreStream());
}
}
const Timestamp received_timestamp = cc->InputTimestamp();
const int64 received_timestamp_idx =
calculator_->TimestampToPeriodIndex(received_timestamp);
// Only consider the received packet if it belongs to the current period
// (== period_count_) or to a newer one (> period_count_).
if (received_timestamp_idx >= period_count_) {
// Fill the empty periods until we are in the same index as the received
// packet.
while (received_timestamp_idx > period_count_) {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(
calculator_->PeriodIndexToTimestamp(period_count_)));
++period_count_;
}
// Now, if the received packet has a timestamp larger than the middle of
// the current period, we can send a packet without waiting. We send the
// one closer to the middle.
Timestamp target_timestamp =
calculator_->PeriodIndexToTimestamp(period_count_);
if (received_timestamp >= target_timestamp) {
bool have_last_packet =
(calculator_->last_packet_.Timestamp() != Timestamp::Unset());
bool send_current =
!have_last_packet ||
(received_timestamp - target_timestamp <=
target_timestamp - calculator_->last_packet_.Timestamp());
if (send_current) {
calculator_->OutputWithinLimits(cc,
cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(target_timestamp));
} else {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(target_timestamp));
}
++period_count_;
}
// TODO: Add a mechanism to the framework to allow these packets
// to be output earlier (without waiting for a much later packet to
// arrive)
// Update the bound for the next packet.
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(
calculator_->PeriodIndexToTimestamp(period_count_));
}
return absl::OkStatus();
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -55,7 +55,7 @@ class PacketReservoir {
// correspond to timestamp t. // correspond to timestamp t.
// - The next packet is chosen randomly (uniform distribution) among frames // - The next packet is chosen randomly (uniform distribution) among frames
// that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate]. // that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate].
// - if jitter_with_reflection_ is true, the timestamp will be reflected // - if jitter_with_reflection is true, the timestamp will be reflected
// against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate) // against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate)
// so that its marginal distribution is uniform within this interval. // so that its marginal distribution is uniform within this interval.
// In the formula, t_0 is the timestamp of the first sampled // In the formula, t_0 is the timestamp of the first sampled
@ -66,6 +66,17 @@ class PacketReservoir {
// the resampling. For Cloud ML Video Intelligence API, the hash of the // the resampling. For Cloud ML Video Intelligence API, the hash of the
// input video should serve this purpose. For YouTube, either video ID or // input video should serve this purpose. For YouTube, either video ID or
// content hex ID of the input video should do. // content hex ID of the input video should do.
// - If reproducible_samping is true, care is taken to allow reproducible
// "mid-stream" sampling. The calculator can be executed on a stream that
// doesn't start at the first period. For instance, if the calculator
// is run on a 10 second stream it will produce the same set of samples
// as two runs of the calculator, the first with 3 seconds of input starting
// at time 0 and the second with 7 seconds of input starting at time +3s.
// - In order to guarantee the exact same samples, 1) the inputs must be
// aligned with the sampling period. For instance, if the sampling rate
// is 2 frames per second, streams should be aligned on 0.5 second
// boundaries, and 2) the stream must include at least one extra packet
// before and after the second aligned sampling period.
// //
// If jitter_ is not specified: // If jitter_ is not specified:
// - The first packet defines the first_timestamp of the output stream, // - The first packet defines the first_timestamp of the output stream,
@ -105,19 +116,6 @@ class PacketResamplerCalculator : public CalculatorBase {
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private:
// Calculates the first sampled timestamp that incorporates a jittering
// offset.
void InitializeNextOutputTimestampWithJitter();
// Calculates the next sampled timestamp that incorporates a jittering offset.
void UpdateNextOutputTimestampWithJitter();
// Logic for Process() when jitter_ != 0.0.
absl::Status ProcessWithJitter(CalculatorContext* cc);
// Logic for Process() when jitter_ == 0.0.
absl::Status ProcessWithoutJitter(CalculatorContext* cc);
// Given the current count of periods that have passed, this returns // Given the current count of periods that have passed, this returns
// the next valid timestamp of the middle point of the next period: // the next valid timestamp of the middle point of the next period:
// if count is 0, it returns the first_timestamp_. // if count is 0, it returns the first_timestamp_.
@ -141,6 +139,16 @@ class PacketResamplerCalculator : public CalculatorBase {
// Outputs a packet if it is in range (start_time_, end_time_). // Outputs a packet if it is in range (start_time_, end_time_).
void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const; void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const;
protected:
// Returns Sampling Strategy to use.
//
// Virtual to allow injection of testing strategies.
virtual std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
const mediapipe::PacketResamplerCalculatorOptions& options);
private:
std::unique_ptr<class PacketResamplerStrategy> strategy_;
// The timestamp of the first packet received. // The timestamp of the first packet received.
Timestamp first_timestamp_; Timestamp first_timestamp_;
@ -150,14 +158,6 @@ class PacketResamplerCalculator : public CalculatorBase {
// Inverse of frame_rate_. // Inverse of frame_rate_.
int64 frame_time_usec_; int64 frame_time_usec_;
// Number of periods that have passed (= #packets sent to the output).
//
// Can only be used if jitter_ equals zero.
int64 period_count_;
// The last packet that was received.
Packet last_packet_;
VideoHeader video_header_; VideoHeader video_header_;
// The "DATA" input stream. // The "DATA" input stream.
CollectionItemId input_data_id_; CollectionItemId input_data_id_;
@ -165,23 +165,15 @@ class PacketResamplerCalculator : public CalculatorBase {
CollectionItemId output_data_id_; CollectionItemId output_data_id_;
// Indicator whether to flush last packet even if its timestamp is greater // Indicator whether to flush last packet even if its timestamp is greater
// than the final stream timestamp. Set to false when jitter_ is non-zero. // than the final stream timestamp.
bool flush_last_packet_; bool flush_last_packet_;
// Jitter-related variables.
std::unique_ptr<RandomBase> random_;
double jitter_ = 0.0; double jitter_ = 0.0;
bool jitter_with_reflection_;
int64 jitter_usec_;
Timestamp next_output_timestamp_;
// If jittering_with_reflection_ is true, next_output_timestamp_ will be
// kept within the interval
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
Timestamp next_output_timestamp_min_;
// If specified, output timestamps are aligned with base_timestamp. int64 jitter_usec_;
// Otherwise, they are aligned with the first input timestamp.
Timestamp base_timestamp_; // The last packet that was received.
Packet last_packet_;
// If specified, only outputs at/after start_time are included. // If specified, only outputs at/after start_time are included.
Timestamp start_time_; Timestamp start_time_;
@ -191,15 +183,210 @@ class PacketResamplerCalculator : public CalculatorBase {
// If set, the output timestamps nearest to start_time and end_time // If set, the output timestamps nearest to start_time and end_time
// are included in the output, even if the nearest timestamp is not // are included in the output, even if the nearest timestamp is not
// between start_time and end_time.W // between start_time and end_time.
bool round_limits_; bool round_limits_;
// Allow strategies access to all internal calculator state.
//
// The calculator and strategies are intimiately tied together so this should
// not break encapsulation.
friend class LegacyJitterWithReflectionStrategy;
friend class ReproducibleJitterWithReflectionStrategy;
friend class JitterWithoutReflectionStrategy;
friend class NoJitterStrategy;
};
// Abstract class encapsulating sampling stategy.
//
// These are used solely by PacketResamplerCalculator, but are exposed here
// to facilitate tests.
class PacketResamplerStrategy {
public:
PacketResamplerStrategy(PacketResamplerCalculator* calculator)
: calculator_(calculator) {}
virtual ~PacketResamplerStrategy() = default;
// Delegate for CalculatorBase::Open. See CalculatorBase for relevant
// implementation considerations.
virtual absl::Status Open(CalculatorContext* cc) = 0;
// Delegate for CalculatorBase::Close. See CalculatorBase for relevant
// implementation considerations.
virtual absl::Status Close(CalculatorContext* cc) = 0;
// Delegate for CalculatorBase::Process. See CalculatorBase for relevant
// implementation considerations.
virtual absl::Status Process(CalculatorContext* cc) = 0;
protected:
// Calculator running strategy.
PacketResamplerCalculator* calculator_;
};
// Strategy that applies Jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both Jitter and reflection are
// enabled.
//
// This applies the legacy jitter with reflection which doesn't allow
// for reproducibility of sampling when starting mid-stream. This is maintained
// for backward compatibility.
class LegacyJitterWithReflectionStrategy : public PacketResamplerStrategy {
public:
LegacyJitterWithReflectionStrategy(PacketResamplerCalculator* calculator)
: PacketResamplerStrategy(calculator) {}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
void InitializeNextOutputTimestampWithJitter();
void UpdateNextOutputTimestampWithJitter();
// Jitter-related variables.
std::unique_ptr<RandomBase> random_;
// The timestamp of the first packet received.
Timestamp first_timestamp_;
// Next packet to be emitted. Since packets may not align perfectly with
// next_output_timestamp_, the closest packet will be emitted.
Timestamp next_output_timestamp_;
// Lower bound for next timestamp.
//
// next_output_timestamp_ will be kept within the interval
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
// packet reservior used for sampling random packet out of partial // packet reservior used for sampling random packet out of partial
// period when jitter is enabled // period when jitter is enabled
std::unique_ptr<PacketReservoir> packet_reservoir_; std::unique_ptr<PacketReservoir> packet_reservoir_;
// random number generator used in packet_reservior_. // random number generator used in packet_reservior_.
std::unique_ptr<RandomBase> packet_reservoir_random_; std::unique_ptr<RandomBase> packet_reservoir_random_;
}; };
// Strategy that applies reproducible jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both Jitter and reflection are
// enabled.
class ReproducibleJitterWithReflectionStrategy
: public PacketResamplerStrategy {
public:
ReproducibleJitterWithReflectionStrategy(
PacketResamplerCalculator* calculator)
: PacketResamplerStrategy(calculator) {}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
protected:
// Returns next random in range (0,n].
//
// Exposed as virtual function for testing Jitter with reflection.
// This is the only way random_ is accessed.
virtual uint64 GetNextRandom(uint64 n) {
return random_->UnbiasedUniform64(n);
}
private:
// Initializes Jitter with reflection.
//
// This will fast-forward to the period containing current_timestamp.
// next_output_timestamp_ is guarnateed to be current_timestamp's period
// and packet_emitted_this_period_ will be set to false.
void InitializeNextOutputTimestamp(Timestamp current_timestamp);
// Potentially advances next_output_timestamp_ a single period.
//
// next_output_timestamp_ will only be advanced if packet_emitted_this_period_
// is false. next_output_timestamp_ will never be advanced beyond
// current_timestamp's period.
//
// However, next_output_timestamp_ could fall before current_timestamp's
// period since only a single period can be advanced at a time.
void UpdateNextOutputTimestamp(Timestamp current_timestamp);
// Jitter-related variables.
std::unique_ptr<RandomBase> random_;
// Next packet to be emitted. Since packets may not align perfectly with
// next_output_timestamp_, the closest packet will be emitted.
Timestamp next_output_timestamp_;
// Lower bound for next timestamp.
//
// next_output_timestamp_ will be kept within the interval
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
// Indicates packet was emitted for current period (i.e. the period
// next_output_timestamp_ falls in.
bool packet_emitted_this_period_ = false;
};
// Strategy that applies Jitter without reflection based sampling.
//
// Used by PacketResamplerCalculator when Jitter is enabled and reflection is
// not enabled.
class JitterWithoutReflectionStrategy : public PacketResamplerStrategy {
public:
JitterWithoutReflectionStrategy(PacketResamplerCalculator* calculator)
: PacketResamplerStrategy(calculator) {}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Calculates the first sampled timestamp that incorporates a jittering
// offset.
void InitializeNextOutputTimestamp();
// Calculates the next sampled timestamp that incorporates a jittering offset.
void UpdateNextOutputTimestamp();
// Jitter-related variables.
std::unique_ptr<RandomBase> random_;
// Next packet to be emitted. Since packets may not align perfectly with
// next_output_timestamp_, the closest packet will be emitted.
Timestamp next_output_timestamp_;
// Lower bound for next timestamp.
//
// next_output_timestamp_ will be kept within the interval
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
// packet reservior used for sampling random packet out of partial period.
std::unique_ptr<PacketReservoir> packet_reservoir_;
// random number generator used in packet_reservior_.
std::unique_ptr<RandomBase> packet_reservoir_random_;
};
// Strategy that applies sampling without any jitter.
//
// Used by PacketResamplerCalculator when jitter is not enabled.
class NoJitterStrategy : public PacketResamplerStrategy {
public:
NoJitterStrategy(PacketResamplerCalculator* calculator)
: PacketResamplerStrategy(calculator) {}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Number of periods that have passed (= #packets sent to the output).
int64 period_count_;
// If specified, output timestamps are aligned with base_timestamp.
// Otherwise, they are aligned with the first input timestamp.
Timestamp base_timestamp_;
};
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_ #endif // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_

View File

@ -68,8 +68,23 @@ message PacketResamplerCalculatorOptions {
// pseudo-random number generator does its job and the number of frames is // pseudo-random number generator does its job and the number of frames is
// sufficiently large, the average frame rate will be close to this value. // sufficiently large, the average frame rate will be close to this value.
optional double jitter = 4; optional double jitter = 4;
// Enables reflection when applying jitter.
//
// This option is ignored when reproducible_sampling is true, in which case
// reflection will be used.
//
// New use cases should use reproducible_sampling = true, as
// jitter_with_reflection is deprecated and will be removed at some point.
optional bool jitter_with_reflection = 9 [default = false]; optional bool jitter_with_reflection = 9 [default = false];
// If set, enabled reproducible sampling, allowing frames to be sampled
// without regards to where the stream starts. See
// packet_resampler_calculator.h for details.
//
// This enables reflection (ignoring jitter_with_reflection setting).
optional bool reproducible_sampling = 10 [default = false];
// If specified, output timestamps are aligned with base_timestamp. // If specified, output timestamps are aligned with base_timestamp.
// Otherwise, they are aligned with the first input timestamp. // Otherwise, they are aligned with the first input timestamp.
// //

View File

@ -30,6 +30,7 @@
namespace mediapipe { namespace mediapipe {
using ::testing::ElementsAre;
namespace { namespace {
// A simple version of CalculatorRunner with built-in convenience // A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs // methods for setting inputs from a vector and checking outputs
@ -96,6 +97,77 @@ class SimpleRunner : public CalculatorRunner {
static int static_count_; static int static_count_;
}; };
// Matcher for Packets with uint64 payload, comparing arg packet's
// timestamp and uint64 payload.
MATCHER_P2(PacketAtTimestamp, payload, timestamp,
absl::StrCat(negation ? "isn't" : "is", " a packet with payload ",
payload, " @ time ", timestamp)) {
if (timestamp != arg.Timestamp().Value()) {
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
return false;
}
int64 actual_payload = arg.template Get<int64>();
if (actual_payload != payload) {
*result_listener << "with incorrect payload = " << actual_payload;
return false;
}
return true;
}
// JitterWithReflectionStrategy child class which injects a specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleJitterWithReflectionStrategyForTesting
: public ReproducibleJitterWithReflectionStrategy {
public:
ReproducibleJitterWithReflectionStrategyForTesting(
PacketResamplerCalculator* calculator)
: ReproducibleJitterWithReflectionStrategy(calculator) {}
// Statically accessed random sequence to use for jitter with reflection.
//
// An EXPECT will fail if sequence is less than the number requested during
// processing.
static std::vector<uint64> random_sequence;
protected:
virtual uint64 GetNextRandom(uint64 n) {
EXPECT_LT(sequence_index_, random_sequence.size());
return random_sequence[sequence_index_++] % n;
}
private:
int32 sequence_index_ = 0;
};
std::vector<uint64>
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
// PacketResamplerCalculator child class which injects a specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleResamplerCalculatorForTesting
: public PacketResamplerCalculator {
public:
static absl::Status GetContract(CalculatorContract* cc) {
return PacketResamplerCalculator::GetContract(cc);
}
protected:
std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
const mediapipe::PacketResamplerCalculatorOptions& Options) {
return absl::make_unique<
ReproducibleJitterWithReflectionStrategyForTesting>(this);
}
};
REGISTER_CALCULATOR(ReproducibleResamplerCalculatorForTesting);
int SimpleRunner::static_count_ = 0; int SimpleRunner::static_count_ = 0;
TEST(PacketResamplerCalculatorTest, NoPacketsInStream) { TEST(PacketResamplerCalculatorTest, NoPacketsInStream) {

View File

@ -561,13 +561,13 @@ cc_test(
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],

View File

@ -15,6 +15,7 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "absl/flags/flag.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
@ -28,7 +29,6 @@
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"

View File

@ -41,7 +41,7 @@ class InferenceCalculatorSelectorImpl
(options.has_delegate() && options.delegate().has_gpu()); (options.has_delegate() && options.delegate().has_gpu());
if (should_use_gpu) { if (should_use_gpu) {
impls.emplace_back("Metal"); impls.emplace_back("Metal");
impls.emplace_back("MlDrift"); impls.emplace_back("MlDriftWebGl");
impls.emplace_back("Gl"); impls.emplace_back("Gl");
} }
impls.emplace_back("Cpu"); impls.emplace_back("Cpu");

View File

@ -118,8 +118,8 @@ struct InferenceCalculatorGl : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; static constexpr char kCalculatorName[] = "InferenceCalculatorGl";
}; };
struct InferenceCalculatorMlDrift : public InferenceCalculator { struct InferenceCalculatorMlDriftWebGl : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorMlDrift"; static constexpr char kCalculatorName[] = "InferenceCalculatorMlDriftWebGl";
}; };
struct InferenceCalculatorMetal : public InferenceCalculator { struct InferenceCalculatorMetal : public InferenceCalculator {

View File

@ -51,12 +51,12 @@ message InferenceCalculatorOptions {
// This option is valid for TFLite GPU delegate API2 only, // This option is valid for TFLite GPU delegate API2 only,
// Choose any of available APIs to force running inference using it. // Choose any of available APIs to force running inference using it.
enum API { enum Api {
ANY = 0; ANY = 0;
OPENGL = 1; OPENGL = 1;
OPENCL = 2; OPENCL = 2;
} }
optional API api = 4 [default = ANY]; optional Api api = 4 [default = ANY];
// This option is valid for TFLite GPU delegate API2 only, // This option is valid for TFLite GPU delegate API2 only,
// Set to true to use 16-bit float precision. If max precision is needed, // Set to true to use 16-bit float precision. If max precision is needed,

View File

@ -136,7 +136,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr( kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver()); tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_); RET_CHECK(interpreter_);

View File

@ -59,7 +59,7 @@ const std::vector<Param>& GetParams() {
p.back().delegate.mutable_gpu(); p.back().delegate.mutable_gpu();
#endif // TARGET_IPHONE_SIMULATOR #endif // TARGET_IPHONE_SIMULATOR
#if __EMSCRIPTEN__ #if __EMSCRIPTEN__
p.push_back({"MlDrift", "MlDrift"}); p.push_back({"MlDriftWebGl", "MlDriftWebGl"});
p.back().delegate.mutable_gpu(); p.back().delegate.mutable_gpu();
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
#if __ANDROID__ && 0 // Disabled for now since emulator can't go GLESv3 #if __ANDROID__ && 0 // Disabled for now since emulator can't go GLESv3

View File

@ -63,7 +63,7 @@ class InferenceCalculatorGlImpl
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_; std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
bool allow_precision_loss_ = false; bool allow_precision_loss_ = false;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_; tflite_gpu_runner_api_;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
@ -244,7 +244,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr( kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver()); tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
// Create runner // Create runner
tflite::gpu::InferenceOptions options; tflite::gpu::InferenceOptions options;
@ -294,7 +294,7 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr( kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver()); tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_); RET_CHECK(interpreter_);

View File

@ -200,7 +200,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr( kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolver()); tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_); RET_CHECK(interpreter_);

View File

@ -892,13 +892,13 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
@ -923,13 +923,13 @@ cc_test(
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
@ -954,11 +954,11 @@ cc_test(
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",
@ -981,11 +981,11 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",
@ -1144,8 +1144,8 @@ cc_test(
":tensorflow_inference_calculator", ":tensorflow_inference_calculator",
":tensorflow_session_from_frozen_graph_generator", ":tensorflow_session_from_frozen_graph_generator",
":tensorflow_session_from_frozen_graph_generator_cc_proto", ":tensorflow_session_from_frozen_graph_generator_cc_proto",
"@com_google_absl//absl/flags:flag",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",

View File

@ -16,12 +16,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/flags/flag.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
@ -20,7 +21,6 @@
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h" #include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h"

View File

@ -56,11 +56,4 @@ message UnpackMediaSequenceCalculatorOptions {
// the clip start and end times and outputs these for the // the clip start and end times and outputs these for the
// AudioDecoderCalculator to consume. // AudioDecoderCalculator to consume.
optional AudioDecoderOptions base_audio_decoder_options = 9; optional AudioDecoderOptions base_audio_decoder_options = 9;
optional string keypoint_names = 10 [
default =
"NOSE,LEFT_EAR,RIGHT_EAR,LEFT_SHOULDER,RIGHT_SHOULDER,LEFT_FORE_PAW,RIGHT_FORE_PAW,LEFT_HIP,RIGHT_HIP,LEFT_HIND_PAW,RIGHT_HIND_PAW,ROOT_TAIL"
];
// When the keypoint doesn't exists, output this default value.
optional float default_keypoint_location = 11 [default = -1.0];
} }

View File

@ -147,11 +147,11 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
], ],
) )

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -278,7 +278,7 @@ class TfLiteInferenceCalculator : public CalculatorBase {
bool use_advanced_gpu_api_ = false; bool use_advanced_gpu_api_ = false;
bool allow_precision_loss_ = false; bool allow_precision_loss_ = false;
mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::API mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_; tflite_gpu_runner_api_;
bool use_kernel_caching_ = false; bool use_kernel_caching_ = false;
@ -702,11 +702,16 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
#if MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GL_INFERENCE
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>(); const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates
default_op_resolver;
auto op_resolver_ptr =
static_cast<const tflite::ops::builtin::BuiltinOpResolver*>(
&default_op_resolver);
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets() op_resolver_ptr = &(cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER") .Tag("CUSTOM_OP_RESOLVER")
.Get<tflite::ops::builtin::BuiltinOpResolver>(); .Get<tflite::ops::builtin::BuiltinOpResolver>());
} }
// Create runner // Create runner
@ -733,7 +738,7 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
} }
} }
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); tflite_gpu_runner_->InitializeWithModel(model, *op_resolver_ptr));
// Allocate interpreter memory for cpu output. // Allocate interpreter memory for cpu output.
if (!gpu_output_) { if (!gpu_output_) {
@ -786,18 +791,24 @@ absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>(); const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates
default_op_resolver;
auto op_resolver_ptr =
static_cast<const tflite::ops::builtin::BuiltinOpResolver*>(
&default_op_resolver);
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
op_resolver = cc->InputSidePackets() op_resolver_ptr = &(cc->InputSidePackets()
.Tag("CUSTOM_OP_RESOLVER") .Tag("CUSTOM_OP_RESOLVER")
.Get<tflite::ops::builtin::BuiltinOpResolver>(); .Get<tflite::ops::builtin::BuiltinOpResolver>());
} }
#if defined(MEDIAPIPE_EDGE_TPU) #if defined(MEDIAPIPE_EDGE_TPU)
interpreter_ = interpreter_ =
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get()); BuildEdgeTpuInterpreter(model, op_resolver_ptr, edgetpu_context_.get());
#else #else
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_);
#endif // MEDIAPIPE_EDGE_TPU #endif // MEDIAPIPE_EDGE_TPU
RET_CHECK(interpreter_); RET_CHECK(interpreter_);

View File

@ -51,12 +51,12 @@ message TfLiteInferenceCalculatorOptions {
// This option is valid for TFLite GPU delegate API2 only, // This option is valid for TFLite GPU delegate API2 only,
// Choose any of available APIs to force running inference using it. // Choose any of available APIs to force running inference using it.
enum API { enum Api {
ANY = 0; ANY = 0;
OPENGL = 1; OPENGL = 1;
OPENCL = 2; OPENCL = 2;
} }
optional API api = 4 [default = ANY]; optional Api api = 4 [default = ANY];
// This option is valid for TFLite GPU delegate API2 only, // This option is valid for TFLite GPU delegate API2 only,
// Set to true to use 16-bit float precision. If max precision is needed, // Set to true to use 16-bit float precision. If max precision is needed,

View File

@ -841,12 +841,39 @@ cc_library(
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util/filtering:one_euro_filter",
"//mediapipe/util/filtering:relative_velocity_filter", "//mediapipe/util/filtering:relative_velocity_filter",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
], ],
alwayslink = 1, alwayslink = 1,
) )
mediapipe_proto_library(
name = "visibility_smoothing_calculator_proto",
srcs = ["visibility_smoothing_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "visibility_smoothing_calculator",
srcs = ["visibility_smoothing_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":visibility_smoothing_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util/filtering:low_pass_filter",
"@com_google_absl//absl/algorithm:container",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "landmarks_to_floats_calculator", name = "landmarks_to_floats_calculator",
srcs = ["landmarks_to_floats_calculator.cc"], srcs = ["landmarks_to_floats_calculator.cc"],
@ -858,7 +885,7 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1194,3 +1221,34 @@ cc_library(
}), }),
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "detection_classifications_merger_calculator",
srcs = ["detection_classifications_merger_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "detection_classifications_merger_calculator_test",
srcs = ["detection_classifications_merger_calculator_test.cc"],
deps = [
":detection_classifications_merger_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)

View File

@ -0,0 +1,149 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
namespace mediapipe {
namespace api2 {
namespace {} // namespace
// Replaces the classification labels and scores from the input `Detection` with
// the ones provided into the input `ClassificationList`. Namely:
// * `label_id[i]` becomes `classification[i].index`
// * `score[i]` becomes `classification[i].score`
// * `label[i]` becomes `classification[i].label` (if present)
//
// In case the input `ClassificationList` contains no results (i.e.
// `classification` is empty, which may happen if the classifier uses a score
// threshold and no confident enough result were returned), the input
// `Detection` is returned unchanged.
//
// This is specifically designed for two-stage detection cascades where the
// detections returned by a standalone detector (typically a class-agnostic
// localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through
// the optional "RECT" or "NORM_RECT" input, e.g:
//
// node {
// calculator: "DetectionsToRectsCalculator"
// # Output of an upstream object detector.
// input_stream: "DETECTION:detection"
// output_stream: "NORM_RECT:norm_rect"
// }
// node {
// calculator: "TfLiteTaskImageClassifierCalculator"
// input_stream: "IMAGE:image"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "CLASSIFICATION_RESULT:classification_result"
// }
// node {
// calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
// input_stream: "CLASSIFICATION_RESULT:classification_result"
// output_stream: "CLASSIFICATION_LIST:classification_list"
// }
// node {
// calculator: "DetectionClassificationsMergerCalculator"
// input_stream: "INPUT_DETECTION:detection"
// input_stream: "CLASSIFICATION_LIST:classification_list"
// # Final output.
// output_stream: "OUTPUT_DETECTION:classified_detection"
// }
//
// Inputs:
// INPUT_DETECTION: `Detection` proto.
// CLASSIFICATION_LIST: `ClassificationList` proto.
//
// Output:
// OUTPUT_DETECTION: modified `Detection` proto.
class DetectionClassificationsMergerCalculator : public Node {
public:
static constexpr Input<Detection> kInputDetection{"INPUT_DETECTION"};
static constexpr Input<ClassificationList> kClassificationList{
"CLASSIFICATION_LIST"};
static constexpr Output<Detection> kOutputDetection{"OUTPUT_DETECTION"};
MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList,
kOutputDetection);
absl::Status Process(CalculatorContext* cc) override;
};
MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator);
absl::Status DetectionClassificationsMergerCalculator::Process(
CalculatorContext* cc) {
if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) {
return absl::OkStatus();
}
RET_CHECK(!kInputDetection(cc).IsEmpty());
RET_CHECK(!kClassificationList(cc).IsEmpty());
Detection detection = *kInputDetection(cc);
const ClassificationList& classification_list = *kClassificationList(cc);
// Update input detection only if classification did return results.
if (classification_list.classification_size() != 0) {
detection.clear_label_id();
detection.clear_score();
detection.clear_label();
detection.clear_display_name();
for (const auto& classification : classification_list.classification()) {
if (!classification.has_index()) {
return absl::InvalidArgumentError(
"Missing required 'index' field in Classification proto.");
}
detection.add_label_id(classification.index());
if (!classification.has_score()) {
return absl::InvalidArgumentError(
"Missing required 'score' field in Classification proto.");
}
detection.add_score(classification.score());
if (classification.has_label()) {
detection.add_label(classification.label());
}
if (classification.has_display_name()) {
detection.add_display_name(classification.display_name());
}
}
// Post-conversion sanity checks.
if (detection.label_size() != 0 &&
detection.label_size() != detection.label_id_size()) {
return absl::InvalidArgumentError(absl::Substitute(
"Each input Classification is expected to either always or never "
"provide a 'label' field. Found $0 'label' fields for $1 "
"'Classification' objects.",
/*$0=*/detection.label_size(), /*$1=*/detection.label_id_size()));
}
if (detection.display_name_size() != 0 &&
detection.display_name_size() != detection.label_id_size()) {
return absl::InvalidArgumentError(absl::Substitute(
"Each input Classification is expected to either always or never "
"provide a 'display_name' field. Found $0 'display_name' fields for "
"$1 'Classification' objects.",
/*$0=*/detection.display_name_size(),
/*$1=*/detection.label_id_size()));
}
}
kOutputDetection(cc).Send(detection);
return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,320 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kGraphConfig[] = R"(
input_stream: "input_detection"
input_stream: "classification_list"
output_stream: "output_detection"
node {
calculator: "DetectionClassificationsMergerCalculator"
input_stream: "INPUT_DETECTION:input_detection"
input_stream: "CLASSIFICATION_LIST:classification_list"
output_stream: "OUTPUT_DETECTION:output_detection"
}
)";
constexpr char kInputDetection[] = R"(
label: "entity"
label_id: 1
score: 0.9
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
}
display_name: "Entity"
)";
// Checks that the input Detection is returned unchanged if the input
// ClassificationList does not contain any result.
TEST(DetectionClassificationsMergerCalculator, SucceedsWithNoClassification) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>("");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Get and validate output.
EXPECT_THAT(output_packets, testing::SizeIs(1));
const Detection& output_detection = output_packets[0].Get<Detection>();
EXPECT_THAT(output_detection, mediapipe::EqualsProto(input_detection));
}
// Checks that merging succeeds when the input ClassificationList includes
// labels and display names.
TEST(DetectionClassificationsMergerCalculator,
SucceedsWithLabelsAndDisplayNames) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" }
classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Get and validate output.
EXPECT_THAT(output_packets, testing::SizeIs(1));
const Detection& output_detection = output_packets[0].Get<Detection>();
EXPECT_THAT(output_detection,
mediapipe::EqualsProto(ParseTextProtoOrDie<Detection>(R"(
label: "dog"
label: "fox"
label_id: 11
label_id: 12
score: 0.5
score: 0.4
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
}
display_name: "Dog"
display_name: "Fox"
)")));
}
// Checks that merging succeeds when the input ClassificationList doesn't
// include labels and display names.
TEST(DetectionClassificationsMergerCalculator,
SucceedsWithoutLabelsAndDisplayNames) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { index: 11 score: 0.5 }
classification { index: 12 score: 0.4 }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Get and validate output.
EXPECT_THAT(output_packets, testing::SizeIs(1));
const Detection& output_detection = output_packets[0].Get<Detection>();
EXPECT_THAT(output_detection,
mediapipe::EqualsProto(ParseTextProtoOrDie<Detection>(R"(
label_id: 11
label_id: 12
score: 0.5
score: 0.4
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
}
)")));
}
// Checks that merging fails if the input ClassificationList misses mandatory
// "index" field.
TEST(DetectionClassificationsMergerCalculator, FailsWithMissingIndex) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { score: 0.5 label: "dog" }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
}
// Checks that merging fails if the input ClassificationList misses mandatory
// "score" field.
TEST(DetectionClassificationsMergerCalculator, FailsWithMissingScore) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { index: 11 label: "dog" }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
}
// Checks that merging fails if the input ClassificationList has an
// inconsistent number of labels.
TEST(DetectionClassificationsMergerCalculator,
FailsWithInconsistentNumberOfLabels) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" }
classification { index: 12 score: 0.4 display_name: "Fox" }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
}
// Checks that merging fails if the input ClassificationList has an
// inconsistent number of display names.
TEST(DetectionClassificationsMergerCalculator,
FailsWithInconsistentNumberOfDisplayNames) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
// Prepare input packets.
const Detection& input_detection =
ParseTextProtoOrDie<Detection>(kInputDetection);
Packet input_detection_packet =
MakePacket<Detection>(input_detection).At(Timestamp(0));
const ClassificationList& classification_list =
ParseTextProtoOrDie<ClassificationList>(R"(
classification { index: 11 score: 0.5 label: "dog" }
classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" }
)");
Packet classification_list_packet =
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
// Catch output.
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_detection", input_detection_packet));
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
classification_list_packet));
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
}
} // namespace
} // namespace mediapipe

View File

@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/filtering/one_euro_filter.h"
#include "mediapipe/util/filtering/relative_velocity_filter.h" #include "mediapipe/util/filtering/relative_velocity_filter.h"
namespace mediapipe { namespace mediapipe {
@ -25,19 +28,54 @@ namespace mediapipe {
namespace { namespace {
constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
using mediapipe::OneEuroFilter;
using mediapipe::RelativeVelocityFilter; using mediapipe::RelativeVelocityFilter;
void NormalizedLandmarksToLandmarks(
const NormalizedLandmarkList& norm_landmarks, const int image_width,
const int image_height, LandmarkList* landmarks) {
for (int i = 0; i < norm_landmarks.landmark_size(); ++i) {
const auto& norm_landmark = norm_landmarks.landmark(i);
auto* landmark = landmarks->add_landmark();
landmark->set_x(norm_landmark.x() * image_width);
landmark->set_y(norm_landmark.y() * image_height);
// Scale Z the same way as X (using image width).
landmark->set_z(norm_landmark.z() * image_width);
landmark->set_visibility(norm_landmark.visibility());
landmark->set_presence(norm_landmark.presence());
}
}
void LandmarksToNormalizedLandmarks(const LandmarkList& landmarks,
const int image_width,
const int image_height,
NormalizedLandmarkList* norm_landmarks) {
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const auto& landmark = landmarks.landmark(i);
auto* norm_landmark = norm_landmarks->add_landmark();
norm_landmark->set_x(landmark.x() / image_width);
norm_landmark->set_y(landmark.y() / image_height);
// Scale Z the same way as X (using image width).
norm_landmark->set_z(landmark.z() / image_width);
norm_landmark->set_visibility(landmark.visibility());
norm_landmark->set_presence(landmark.presence());
}
}
// Estimate object scale to use its inverse value as velocity scale for // Estimate object scale to use its inverse value as velocity scale for
// RelativeVelocityFilter. If value will be too small (less than // RelativeVelocityFilter. If value will be too small (less than
// `options_.min_allowed_object_scale`) smoothing will be disabled and // `options_.min_allowed_object_scale`) smoothing will be disabled and
// landmarks will be returned as is. // landmarks will be returned as is.
// Object scale is calculated as average between bounding box width and height // Object scale is calculated as average between bounding box width and height
// with sides parallel to axis. // with sides parallel to axis.
float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width, float GetObjectScale(const LandmarkList& landmarks) {
int image_height) {
const auto& lm_minmax_x = absl::c_minmax_element( const auto& lm_minmax_x = absl::c_minmax_element(
landmarks.landmark(), landmarks.landmark(),
[](const auto& a, const auto& b) { return a.x() < b.x(); }); [](const auto& a, const auto& b) { return a.x() < b.x(); });
@ -50,8 +88,8 @@ float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width,
const float y_min = lm_minmax_y.first->y(); const float y_min = lm_minmax_y.first->y();
const float y_max = lm_minmax_y.second->y(); const float y_max = lm_minmax_y.second->y();
const float object_width = (x_max - x_min) * image_width; const float object_width = x_max - x_min;
const float object_height = (y_max - y_min) * image_height; const float object_height = y_max - y_min;
return (object_width + object_height) / 2.0f; return (object_width + object_height) / 2.0f;
} }
@ -63,19 +101,17 @@ class LandmarksFilter {
virtual absl::Status Reset() { return absl::OkStatus(); } virtual absl::Status Reset() { return absl::OkStatus(); }
virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks, virtual absl::Status Apply(const LandmarkList& in_landmarks,
const std::pair<int, int>& image_size,
const absl::Duration& timestamp, const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) = 0; LandmarkList* out_landmarks) = 0;
}; };
// Returns landmarks as is without smoothing. // Returns landmarks as is without smoothing.
class NoFilter : public LandmarksFilter { class NoFilter : public LandmarksFilter {
public: public:
absl::Status Apply(const NormalizedLandmarkList& in_landmarks, absl::Status Apply(const LandmarkList& in_landmarks,
const std::pair<int, int>& image_size,
const absl::Duration& timestamp, const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) override { LandmarkList* out_landmarks) override {
*out_landmarks = in_landmarks; *out_landmarks = in_landmarks;
return absl::OkStatus(); return absl::OkStatus();
} }
@ -85,10 +121,11 @@ class NoFilter : public LandmarksFilter {
class VelocityFilter : public LandmarksFilter { class VelocityFilter : public LandmarksFilter {
public: public:
VelocityFilter(int window_size, float velocity_scale, VelocityFilter(int window_size, float velocity_scale,
float min_allowed_object_scale) float min_allowed_object_scale, bool disable_value_scaling)
: window_size_(window_size), : window_size_(window_size),
velocity_scale_(velocity_scale), velocity_scale_(velocity_scale),
min_allowed_object_scale_(min_allowed_object_scale) {} min_allowed_object_scale_(min_allowed_object_scale),
disable_value_scaling_(disable_value_scaling) {}
absl::Status Reset() override { absl::Status Reset() override {
x_filters_.clear(); x_filters_.clear();
@ -97,45 +134,37 @@ class VelocityFilter : public LandmarksFilter {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Apply(const NormalizedLandmarkList& in_landmarks, absl::Status Apply(const LandmarkList& in_landmarks,
const std::pair<int, int>& image_size,
const absl::Duration& timestamp, const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) override { LandmarkList* out_landmarks) override {
// Get image size.
int image_width;
int image_height;
std::tie(image_width, image_height) = image_size;
// Get value scale as inverse value of the object scale. // Get value scale as inverse value of the object scale.
// If value is too small smoothing will be disabled and landmarks will be // If value is too small smoothing will be disabled and landmarks will be
// returned as is. // returned as is.
const float object_scale = float value_scale = 1.0f;
GetObjectScale(in_landmarks, image_width, image_height); if (!disable_value_scaling_) {
const float object_scale = GetObjectScale(in_landmarks);
if (object_scale < min_allowed_object_scale_) { if (object_scale < min_allowed_object_scale_) {
*out_landmarks = in_landmarks; *out_landmarks = in_landmarks;
return absl::OkStatus(); return absl::OkStatus();
} }
const float value_scale = 1.0f / object_scale; value_scale = 1.0f / object_scale;
}
// Initialize filters once. // Initialize filters once.
MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size()));
// Filter landmarks. Every axis of every landmark is filtered separately. // Filter landmarks. Every axis of every landmark is filtered separately.
for (int i = 0; i < in_landmarks.landmark_size(); ++i) { for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
const NormalizedLandmark& in_landmark = in_landmarks.landmark(i); const auto& in_landmark = in_landmarks.landmark(i);
NormalizedLandmark* out_landmark = out_landmarks->add_landmark(); auto* out_landmark = out_landmarks->add_landmark();
*out_landmark = in_landmark; *out_landmark = in_landmark;
out_landmark->set_x(x_filters_[i].Apply(timestamp, value_scale, out_landmark->set_x(
in_landmark.x() * image_width) / x_filters_[i].Apply(timestamp, value_scale, in_landmark.x()));
image_width); out_landmark->set_y(
out_landmark->set_y(y_filters_[i].Apply(timestamp, value_scale, y_filters_[i].Apply(timestamp, value_scale, in_landmark.y()));
in_landmark.y() * image_height) / out_landmark->set_z(
image_height); z_filters_[i].Apply(timestamp, value_scale, in_landmark.z()));
// Scale Z the save was as X (using image width).
out_landmark->set_z(z_filters_[i].Apply(timestamp, value_scale,
in_landmark.z() * image_width) /
image_width);
} }
return absl::OkStatus(); return absl::OkStatus();
@ -165,12 +194,83 @@ class VelocityFilter : public LandmarksFilter {
int window_size_; int window_size_;
float velocity_scale_; float velocity_scale_;
float min_allowed_object_scale_; float min_allowed_object_scale_;
bool disable_value_scaling_;
std::vector<RelativeVelocityFilter> x_filters_; std::vector<RelativeVelocityFilter> x_filters_;
std::vector<RelativeVelocityFilter> y_filters_; std::vector<RelativeVelocityFilter> y_filters_;
std::vector<RelativeVelocityFilter> z_filters_; std::vector<RelativeVelocityFilter> z_filters_;
}; };
// Please check OneEuroFilter documentation for details.
class OneEuroFilterImpl : public LandmarksFilter {
public:
OneEuroFilterImpl(double frequency, double min_cutoff, double beta,
double derivate_cutoff)
: frequency_(frequency),
min_cutoff_(min_cutoff),
beta_(beta),
derivate_cutoff_(derivate_cutoff) {}
absl::Status Reset() override {
x_filters_.clear();
y_filters_.clear();
z_filters_.clear();
return absl::OkStatus();
}
absl::Status Apply(const LandmarkList& in_landmarks,
const absl::Duration& timestamp,
LandmarkList* out_landmarks) override {
// Initialize filters once.
MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size()));
// Filter landmarks. Every axis of every landmark is filtered separately.
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
const auto& in_landmark = in_landmarks.landmark(i);
auto* out_landmark = out_landmarks->add_landmark();
*out_landmark = in_landmark;
out_landmark->set_x(x_filters_[i].Apply(timestamp, in_landmark.x()));
out_landmark->set_y(y_filters_[i].Apply(timestamp, in_landmark.y()));
out_landmark->set_z(z_filters_[i].Apply(timestamp, in_landmark.z()));
}
return absl::OkStatus();
}
private:
// Initializes filters for the first time or after Reset. If initialized then
// check the size.
absl::Status InitializeFiltersIfEmpty(const int n_landmarks) {
if (!x_filters_.empty()) {
RET_CHECK_EQ(x_filters_.size(), n_landmarks);
RET_CHECK_EQ(y_filters_.size(), n_landmarks);
RET_CHECK_EQ(z_filters_.size(), n_landmarks);
return absl::OkStatus();
}
for (int i = 0; i < n_landmarks; ++i) {
x_filters_.push_back(
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
y_filters_.push_back(
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
z_filters_.push_back(
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
}
return absl::OkStatus();
}
double frequency_;
double min_cutoff_;
double beta_;
double derivate_cutoff_;
std::vector<OneEuroFilter> x_filters_;
std::vector<OneEuroFilter> y_filters_;
std::vector<OneEuroFilter> z_filters_;
};
} // namespace } // namespace
// A calculator to smooth landmarks over time. // A calculator to smooth landmarks over time.
@ -207,16 +307,21 @@ class LandmarksSmoothingCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private: private:
LandmarksFilter* landmarks_filter_; std::unique_ptr<LandmarksFilter> landmarks_filter_;
}; };
REGISTER_CALCULATOR(LandmarksSmoothingCalculator); REGISTER_CALCULATOR(LandmarksSmoothingCalculator);
absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>(); cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>();
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
cc->Outputs() cc->Outputs()
.Tag(kNormalizedFilteredLandmarksTag) .Tag(kNormalizedFilteredLandmarksTag)
.Set<NormalizedLandmarkList>(); .Set<NormalizedLandmarkList>();
} else {
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
cc->Outputs().Tag(kFilteredLandmarksTag).Set<LandmarkList>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -227,12 +332,19 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) {
// Pick landmarks filter. // Pick landmarks filter.
const auto& options = cc->Options<LandmarksSmoothingCalculatorOptions>(); const auto& options = cc->Options<LandmarksSmoothingCalculatorOptions>();
if (options.has_no_filter()) { if (options.has_no_filter()) {
landmarks_filter_ = new NoFilter(); landmarks_filter_ = absl::make_unique<NoFilter>();
} else if (options.has_velocity_filter()) { } else if (options.has_velocity_filter()) {
landmarks_filter_ = new VelocityFilter( landmarks_filter_ = absl::make_unique<VelocityFilter>(
options.velocity_filter().window_size(), options.velocity_filter().window_size(),
options.velocity_filter().velocity_scale(), options.velocity_filter().velocity_scale(),
options.velocity_filter().min_allowed_object_scale()); options.velocity_filter().min_allowed_object_scale(),
options.velocity_filter().disable_value_scaling());
} else if (options.has_one_euro_filter()) {
landmarks_filter_ = absl::make_unique<OneEuroFilterImpl>(
options.one_euro_filter().frequency(),
options.one_euro_filter().min_cutoff(),
options.one_euro_filter().beta(),
options.one_euro_filter().derivate_cutoff());
} else { } else {
RET_CHECK_FAIL() RET_CHECK_FAIL()
<< "Landmarks filter is either not specified or not supported"; << "Landmarks filter is either not specified or not supported";
@ -244,25 +356,53 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) {
absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) {
// Check that landmarks are not empty and reset the filter if so. // Check that landmarks are not empty and reset the filter if so.
// Don't emit an empty packet for this timestamp. // Don't emit an empty packet for this timestamp.
if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) &&
cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) ||
(cc->Inputs().HasTag(kLandmarksTag) &&
cc->Inputs().Tag(kLandmarksTag).IsEmpty())) {
MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); MP_RETURN_IF_ERROR(landmarks_filter_->Reset());
return absl::OkStatus(); return absl::OkStatus();
} }
const auto& in_landmarks =
cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
const auto& image_size =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
const auto& timestamp = const auto& timestamp =
absl::Microseconds(cc->InputTimestamp().Microseconds()); absl::Microseconds(cc->InputTimestamp().Microseconds());
auto out_landmarks = absl::make_unique<NormalizedLandmarkList>(); if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
MP_RETURN_IF_ERROR(landmarks_filter_->Apply(in_landmarks, image_size, const auto& in_norm_landmarks =
timestamp, out_landmarks.get())); cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
int image_width;
int image_height;
std::tie(image_width, image_height) =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
auto in_landmarks = absl::make_unique<LandmarkList>();
NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, image_height,
in_landmarks.get());
auto out_landmarks = absl::make_unique<LandmarkList>();
MP_RETURN_IF_ERROR(landmarks_filter_->Apply(*in_landmarks, timestamp,
out_landmarks.get()));
auto out_norm_landmarks = absl::make_unique<NormalizedLandmarkList>();
LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height,
out_norm_landmarks.get());
cc->Outputs() cc->Outputs()
.Tag(kNormalizedFilteredLandmarksTag) .Tag(kNormalizedFilteredLandmarksTag)
.Add(out_norm_landmarks.release(), cc->InputTimestamp());
} else {
const auto& in_landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
auto out_landmarks = absl::make_unique<LandmarkList>();
MP_RETURN_IF_ERROR(
landmarks_filter_->Apply(in_landmarks, timestamp, out_landmarks.get()));
cc->Outputs()
.Tag(kFilteredLandmarksTag)
.Add(out_landmarks.release(), cc->InputTimestamp()); .Add(out_landmarks.release(), cc->InputTimestamp());
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -39,10 +39,40 @@ message LandmarksSmoothingCalculatorOptions {
// If calculated object scale is less than given value smoothing will be // If calculated object scale is less than given value smoothing will be
// disabled and landmarks will be returned as is. // disabled and landmarks will be returned as is.
optional float min_allowed_object_scale = 3 [default = 1e-6]; optional float min_allowed_object_scale = 3 [default = 1e-6];
// Disable value scaling based on object size and use `1.0` instead.
// Value scale is calculated as inverse value of object size. Object size is
// calculated as maximum side of rectangular bounding box of the object in
// XY plane.
optional bool disable_value_scaling = 4 [default = false];
}
// For the details of the filter implementation and the procedure of its
// configuration please check http://cristal.univ-lille.fr/~casiez/1euro/
message OneEuroFilter {
// Frequency of incomming frames defined in seconds. Used only if can't be
// calculated from provided events (e.g. on the very first frame).
optional float frequency = 1 [default = 0.033];
// Minimum cutoff frequency. Start by tuning this parameter while keeping
// `beta = 0` to reduce jittering to the desired level. 1Hz (the default
// value) is a good starting point.
optional float min_cutoff = 2 [default = 1.0];
// Cutoff slope. After `min_cutoff` is configured, start increasing `beta`
// value to reduce the lag introduced by the `min_cutoff`. Find the desired
// balance between jittering and lag.
optional float beta = 3 [default = 0.0];
// Cutoff frequency for derivate. It is set to 1Hz in the original
// algorithm, but can be tuned to further smooth the speed (i.e. derivate)
// on the object.
optional float derivate_cutoff = 4 [default = 1.0];
} }
oneof filter_options { oneof filter_options {
NoFilter no_filter = 1; NoFilter no_filter = 1;
VelocityFilter velocity_filter = 2; VelocityFilter velocity_filter = 2;
OneEuroFilter one_euro_filter = 3;
} }
} }

View File

@ -34,6 +34,34 @@ constexpr char kRenderScaleTag[] = "RENDER_SCALE";
constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kLandmarkLabel[] = "KEYPOINT"; constexpr char kLandmarkLabel[] = "KEYPOINT";
inline Color DefaultMinDepthLineColor() {
Color color;
color.set_r(0);
color.set_g(0);
color.set_b(0);
return color;
}
inline Color DefaultMaxDepthLineColor() {
Color color;
color.set_r(255);
color.set_g(255);
color.set_b(255);
return color;
}
inline Color MixColors(const Color& color1, const Color& color2,
float color1_weight) {
Color color;
color.set_r(static_cast<int>(color1.r() * color1_weight +
color2.r() * (1.f - color1_weight)));
color.set_g(static_cast<int>(color1.g() * color1_weight +
color2.g() * (1.f - color1_weight)));
color.set_b(static_cast<int>(color1.b() * color1_weight +
color2.b() * (1.f - color1_weight)));
return color;
}
inline void SetColor(RenderAnnotation* annotation, const Color& color) { inline void SetColor(RenderAnnotation* annotation, const Color& color) {
annotation->mutable_color()->set_r(color.r()); annotation->mutable_color()->set_r(color.r());
annotation->mutable_color()->set_g(color.g()); annotation->mutable_color()->set_g(color.g());
@ -57,6 +85,23 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min,
} }
} }
template <class LandmarkType>
bool IsLandmarkVisibileAndPresent(const LandmarkType& landmark,
bool utilize_visibility,
float visibility_threshold,
bool utilize_presence,
float presence_threshold) {
if (utilize_visibility && landmark.has_visibility() &&
landmark.visibility() < visibility_threshold) {
return false;
}
if (utilize_presence && landmark.has_presence() &&
landmark.presence() < presence_threshold) {
return false;
}
return true;
}
void SetColorSizeValueFromZ(float z, float z_min, float z_max, void SetColorSizeValueFromZ(float z, float z_min, float z_max,
RenderAnnotation* render_annotation, RenderAnnotation* render_annotation,
float min_depth_circle_thickness, float min_depth_circle_thickness,
@ -75,8 +120,9 @@ void SetColorSizeValueFromZ(float z, float z_min, float z_max,
template <class LandmarkType> template <class LandmarkType>
void AddConnectionToRenderData(const LandmarkType& start, void AddConnectionToRenderData(const LandmarkType& start,
const LandmarkType& end, int gray_val1, const LandmarkType& end,
int gray_val2, float thickness, bool normalized, const Color& color_start, const Color& color_end,
float thickness, bool normalized,
RenderData* render_data) { RenderData* render_data) {
auto* connection_annotation = render_data->add_render_annotations(); auto* connection_annotation = render_data->add_render_annotations();
RenderAnnotation::GradientLine* line = RenderAnnotation::GradientLine* line =
@ -86,12 +132,13 @@ void AddConnectionToRenderData(const LandmarkType& start,
line->set_x_end(end.x()); line->set_x_end(end.x());
line->set_y_end(end.y()); line->set_y_end(end.y());
line->set_normalized(normalized); line->set_normalized(normalized);
line->mutable_color1()->set_r(gray_val1); line->mutable_color1()->set_r(color_start.r());
line->mutable_color1()->set_g(gray_val1); line->mutable_color1()->set_g(color_start.g());
line->mutable_color1()->set_b(gray_val1); line->mutable_color1()->set_b(color_start.b());
line->mutable_color2()->set_r(gray_val2); line->mutable_color2()->set_r(color_end.r());
line->mutable_color2()->set_g(gray_val2); line->mutable_color2()->set_g(color_end.g());
line->mutable_color2()->set_b(gray_val2); line->mutable_color2()->set_b(color_end.b());
connection_annotation->set_thickness(thickness); connection_annotation->set_thickness(thickness);
} }
@ -102,26 +149,26 @@ void AddConnectionsWithDepth(const LandmarkListType& landmarks,
float visibility_threshold, bool utilize_presence, float visibility_threshold, bool utilize_presence,
float presence_threshold, float thickness, float presence_threshold, float thickness,
bool normalized, float min_z, float max_z, bool normalized, float min_z, float max_z,
const Color& min_depth_line_color,
const Color& max_depth_line_color,
RenderData* render_data) { RenderData* render_data) {
for (int i = 0; i < landmark_connections.size(); i += 2) { for (int i = 0; i < landmark_connections.size(); i += 2) {
const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld0 = landmarks.landmark(landmark_connections[i]);
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
if (utilize_visibility && if (!IsLandmarkVisibileAndPresent<LandmarkType>(
((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || ld0, utilize_visibility, visibility_threshold, utilize_presence,
(ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { presence_threshold) ||
!IsLandmarkVisibileAndPresent<LandmarkType>(
ld1, utilize_visibility, visibility_threshold, utilize_presence,
presence_threshold)) {
continue; continue;
} }
if (utilize_presence && const Color color0 = MixColors(min_depth_line_color, max_depth_line_color,
((ld0.has_presence() && ld0.presence() < presence_threshold) || Remap(ld0.z(), min_z, max_z, 1.f));
(ld1.has_presence() && ld1.presence() < presence_threshold))) { const Color color1 = MixColors(min_depth_line_color, max_depth_line_color,
continue; Remap(ld1.z(), min_z, max_z, 1.f));
} AddConnectionToRenderData<LandmarkType>(ld0, ld1, color0, color1, thickness,
const int gray_val1 = normalized, render_data);
255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
const int gray_val2 =
255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
AddConnectionToRenderData<LandmarkType>(ld0, ld1, gray_val1, gray_val2,
thickness, normalized, render_data);
} }
} }
@ -151,14 +198,12 @@ void AddConnections(const LandmarkListType& landmarks,
for (int i = 0; i < landmark_connections.size(); i += 2) { for (int i = 0; i < landmark_connections.size(); i += 2) {
const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld0 = landmarks.landmark(landmark_connections[i]);
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
if (utilize_visibility && if (!IsLandmarkVisibileAndPresent<LandmarkType>(
((ld0.has_visibility() && ld0.visibility() < visibility_threshold) || ld0, utilize_visibility, visibility_threshold, utilize_presence,
(ld1.has_visibility() && ld1.visibility() < visibility_threshold))) { presence_threshold) ||
continue; !IsLandmarkVisibileAndPresent<LandmarkType>(
} ld1, utilize_visibility, visibility_threshold, utilize_presence,
if (utilize_presence && presence_threshold)) {
((ld0.has_presence() && ld0.presence() < presence_threshold) ||
(ld1.has_presence() && ld1.presence() < presence_threshold))) {
continue; continue;
} }
AddConnectionToRenderData<LandmarkType>(ld0, ld1, connection_color, AddConnectionToRenderData<LandmarkType>(ld0, ld1, connection_color,
@ -232,6 +277,13 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
float z_min = 0.f; float z_min = 0.f;
float z_max = 0.f; float z_max = 0.f;
const Color min_depth_line_color = options_.has_min_depth_line_color()
? options_.min_depth_line_color()
: DefaultMinDepthLineColor();
const Color max_depth_line_color = options_.has_max_depth_line_color()
? options_.max_depth_line_color()
: DefaultMaxDepthLineColor();
// Apply scale to `thickness` of rendered landmarks and connections to make // Apply scale to `thickness` of rendered landmarks and connections to make
// them bigger when object (e.g. pose, hand or face) is closer/bigger and // them bigger when object (e.g. pose, hand or face) is closer/bigger and
// snaller when object is further/smaller. // snaller when object is further/smaller.
@ -254,7 +306,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
landmarks, landmark_connections_, options_.utilize_visibility(), landmarks, landmark_connections_, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(), options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold(), thickness, /*normalized=*/false, z_min, options_.presence_threshold(), thickness, /*normalized=*/false, z_min,
z_max, render_data.get()); z_max, min_depth_line_color, max_depth_line_color, render_data.get());
} else { } else {
AddConnections<LandmarkList, Landmark>( AddConnections<LandmarkList, Landmark>(
landmarks, landmark_connections_, options_.utilize_visibility(), landmarks, landmark_connections_, options_.utilize_visibility(),
@ -265,13 +317,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < landmarks.landmark_size(); ++i) { for (int i = 0; i < landmarks.landmark_size(); ++i) {
const Landmark& landmark = landmarks.landmark(i); const Landmark& landmark = landmarks.landmark(i);
if (options_.utilize_visibility() && landmark.has_visibility() && if (!IsLandmarkVisibileAndPresent<Landmark>(
landmark.visibility() < options_.visibility_threshold()) { landmark, options_.utilize_visibility(),
continue; options_.visibility_threshold(), options_.utilize_presence(),
} options_.presence_threshold())) {
if (options_.utilize_presence() && landmark.has_presence() &&
landmark.presence() < options_.presence_threshold()) {
continue; continue;
} }
@ -303,7 +352,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
landmarks, landmark_connections_, options_.utilize_visibility(), landmarks, landmark_connections_, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(), options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold(), thickness, /*normalized=*/true, z_min, options_.presence_threshold(), thickness, /*normalized=*/true, z_min,
z_max, render_data.get()); z_max, min_depth_line_color, max_depth_line_color, render_data.get());
} else { } else {
AddConnections<NormalizedLandmarkList, NormalizedLandmark>( AddConnections<NormalizedLandmarkList, NormalizedLandmark>(
landmarks, landmark_connections_, options_.utilize_visibility(), landmarks, landmark_connections_, options_.utilize_visibility(),
@ -314,12 +363,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < landmarks.landmark_size(); ++i) { for (int i = 0; i < landmarks.landmark_size(); ++i) {
const NormalizedLandmark& landmark = landmarks.landmark(i); const NormalizedLandmark& landmark = landmarks.landmark(i);
if (options_.utilize_visibility() && landmark.has_visibility() && if (!IsLandmarkVisibileAndPresent<NormalizedLandmark>(
landmark.visibility() < options_.visibility_threshold()) { landmark, options_.utilize_visibility(),
continue; options_.visibility_threshold(), options_.utilize_presence(),
} options_.presence_threshold())) {
if (options_.utilize_presence() && landmark.has_presence() &&
landmark.presence() < options_.presence_threshold()) {
continue; continue;
} }

View File

@ -64,4 +64,10 @@ message LandmarksToRenderDataCalculatorOptions {
// Max thickness of the drawing for landmark circle. // Max thickness of the drawing for landmark circle.
optional double max_depth_circle_thickness = 11 [default = 18.0]; optional double max_depth_circle_thickness = 11 [default = 18.0];
// Gradient color for the lines connecting landmarks at the minimum depth.
optional Color min_depth_line_color = 12;
// Gradient color for the lines connecting landmarks at the maximum depth.
optional Color max_depth_line_color = 13;
} }

View File

@ -0,0 +1,194 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/algorithm/container.h"
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace {
constexpr char kLandmarksFromTag[] = "LANDMARKS_FROM";
constexpr char kNormalizedLandmarksFromTag[] = "NORM_LANDMARKS_FROM";
constexpr char kLandmarksToTag[] = "LANDMARKS_TO";
constexpr char kNormalizedLandmarksToTag[] = "NORM_LANDMARKS_TO";
} // namespace
// A calculator to copy visibility and presence between landmarks.
//
// Landmarks to copy from and to copy to can be of different type (normalized or
// non-normalized), but ladnmarks to copy to and output landmarks should be of
// the same type. Exactly one stream to copy landmarks from, to copy to and to
// output should be provided.
//
// Inputs:
// LANDMARKS_FROM (optional): A LandmarkList of landmarks to copy from.
// NORM_LANDMARKS_FROM (optional): A NormalizedLandmarkList of landmarks to
// copy from.
// LANDMARKS_TO (optional): A LandmarkList of landmarks to copy to.
// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks to copy
// to.
//
// Outputs:
// LANDMARKS_TO (optional): A LandmarkList of landmarks from LANDMARKS_TO and
// visibility/presence from LANDMARKS_FROM or NORM_LANDMARKS_FROM.
// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks to copy
// to.
//
// Example config:
// node {
// calculator: "VisibilityCopyCalculator"
// input_stream: "NORM_LANDMARKS_FROM:pose_landmarks"
// input_stream: "LANDMARKS_TO:pose_world_landmarks"
// output_stream: "LANDMARKS_TO:pose_world_landmarks_with_visibility"
// options: {
// [mediapipe.VisibilityCopyCalculatorOptions.ext] {
// copy_visibility: true
// copy_presence: true
// }
// }
// }
//
class VisibilityCopyCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
template <class LandmarkFromType, class LandmarkToType>
absl::Status CopyVisibility(CalculatorContext* cc,
const std::string& landmarks_from_tag,
const std::string& landmarks_to_tag);
bool copy_visibility_;
bool copy_presence_;
};
REGISTER_CALCULATOR(VisibilityCopyCalculator);
absl::Status VisibilityCopyCalculator::GetContract(CalculatorContract* cc) {
// Landmarks to copy from.
RET_CHECK(cc->Inputs().HasTag(kLandmarksFromTag) ^
cc->Inputs().HasTag(kNormalizedLandmarksFromTag))
<< "Exatly one landmarks stream to copy from should be provided";
if (cc->Inputs().HasTag(kLandmarksFromTag)) {
cc->Inputs().Tag(kLandmarksFromTag).Set<LandmarkList>();
} else {
cc->Inputs().Tag(kNormalizedLandmarksFromTag).Set<NormalizedLandmarkList>();
}
// Landmarks to copy to and corresponding output landmarks.
RET_CHECK(cc->Inputs().HasTag(kLandmarksToTag) ^
cc->Inputs().HasTag(kNormalizedLandmarksToTag))
<< "Exatly one landmarks stream to copy to should be provided";
if (cc->Inputs().HasTag(kLandmarksToTag)) {
cc->Inputs().Tag(kLandmarksToTag).Set<LandmarkList>();
RET_CHECK(cc->Outputs().HasTag(kLandmarksToTag))
<< "Landmarks to copy to and output stream types should be the same";
cc->Outputs().Tag(kLandmarksToTag).Set<LandmarkList>();
} else {
cc->Inputs().Tag(kNormalizedLandmarksToTag).Set<NormalizedLandmarkList>();
RET_CHECK(cc->Outputs().HasTag(kNormalizedLandmarksToTag))
<< "Landmarks to copy to and output stream types should be the same";
cc->Outputs().Tag(kNormalizedLandmarksToTag).Set<NormalizedLandmarkList>();
}
return absl::OkStatus();
}
absl::Status VisibilityCopyCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
const auto& options = cc->Options<VisibilityCopyCalculatorOptions>();
copy_visibility_ = options.copy_visibility();
copy_presence_ = options.copy_presence();
return absl::OkStatus();
}
absl::Status VisibilityCopyCalculator::Process(CalculatorContext* cc) {
// Switch between all four possible combinations of landmarks from and
// landmarks to types (normalized and non-normalized).
auto status = absl::OkStatus();
if (cc->Inputs().HasTag(kLandmarksFromTag)) {
if (cc->Inputs().HasTag(kLandmarksToTag)) {
status = CopyVisibility<LandmarkList, LandmarkList>(cc, kLandmarksFromTag,
kLandmarksToTag);
} else {
status = CopyVisibility<LandmarkList, NormalizedLandmarkList>(
cc, kLandmarksFromTag, kNormalizedLandmarksToTag);
}
} else {
if (cc->Inputs().HasTag(kLandmarksToTag)) {
status = CopyVisibility<NormalizedLandmarkList, LandmarkList>(
cc, kNormalizedLandmarksFromTag, kLandmarksToTag);
} else {
status = CopyVisibility<NormalizedLandmarkList, NormalizedLandmarkList>(
cc, kNormalizedLandmarksFromTag, kNormalizedLandmarksToTag);
}
}
return status;
}
template <class LandmarkFromType, class LandmarkToType>
absl::Status VisibilityCopyCalculator::CopyVisibility(
CalculatorContext* cc, const std::string& landmarks_from_tag,
const std::string& landmarks_to_tag) {
// Check that both landmarks to copy from and to copy to are non empty.
if (cc->Inputs().Tag(landmarks_from_tag).IsEmpty() ||
cc->Inputs().Tag(landmarks_to_tag).IsEmpty()) {
return absl::OkStatus();
}
const auto landmarks_from =
cc->Inputs().Tag(landmarks_from_tag).Get<LandmarkFromType>();
const auto landmarks_to =
cc->Inputs().Tag(landmarks_to_tag).Get<LandmarkToType>();
auto landmarks_out = absl::make_unique<LandmarkToType>();
for (int i = 0; i < landmarks_from.landmark_size(); ++i) {
const auto& landmark_from = landmarks_from.landmark(i);
const auto& landmark_to = landmarks_to.landmark(i);
// Create output landmark and copy all fields from the `to` landmark.
const auto& landmark_out = landmarks_out->add_landmark();
*landmark_out = landmark_to;
// Copy visibility and presence from the `from` landmark.
if (copy_visibility_) {
landmark_out->set_visibility(landmark_from.visibility());
}
if (copy_presence_) {
landmark_out->set_presence(landmark_from.presence());
}
}
cc->Outputs()
.Tag(landmarks_to_tag)
.Add(landmarks_out.release(), cc->InputTimestamp());
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator_options.proto";
message VisibilityCopyCalculatorOptions {
extend CalculatorOptions {
optional VisibilityCopyCalculatorOptions ext = 363728421;
}
optional bool copy_visibility = 1 [default = true];
optional bool copy_presence = 2 [default = true];
}

View File

@ -0,0 +1,243 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/algorithm/container.h"
#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/filtering/low_pass_filter.h"
namespace mediapipe {
namespace {
constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
using mediapipe::LowPassFilter;
// Abstract class for various visibility filters.
class VisibilityFilter {
public:
virtual ~VisibilityFilter() = default;
virtual absl::Status Reset() { return absl::OkStatus(); }
virtual absl::Status Apply(const LandmarkList& in_landmarks,
const absl::Duration& timestamp,
LandmarkList* out_landmarks) = 0;
virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) = 0;
};
// Returns visibility as is without smoothing.
class NoFilter : public VisibilityFilter {
public:
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) override {
*out_landmarks = in_landmarks;
return absl::OkStatus();
}
absl::Status Apply(const LandmarkList& in_landmarks,
const absl::Duration& timestamp,
LandmarkList* out_landmarks) override {
*out_landmarks = in_landmarks;
return absl::OkStatus();
}
};
// Please check LowPassFilter documentation for details.
class LowPassVisibilityFilter : public VisibilityFilter {
public:
LowPassVisibilityFilter(float alpha) : alpha_(alpha) {}
absl::Status Reset() override {
visibility_filters_.clear();
return absl::OkStatus();
}
absl::Status Apply(const LandmarkList& in_landmarks,
const absl::Duration& timestamp,
LandmarkList* out_landmarks) override {
return ApplyImpl<LandmarkList>(in_landmarks, timestamp, out_landmarks);
}
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
const absl::Duration& timestamp,
NormalizedLandmarkList* out_landmarks) override {
return ApplyImpl<NormalizedLandmarkList>(in_landmarks, timestamp,
out_landmarks);
}
private:
template <class LandmarksType>
absl::Status ApplyImpl(const LandmarksType& in_landmarks,
const absl::Duration& timestamp,
LandmarksType* out_landmarks) {
// Initializes filters for the first time or after Reset. If initialized
// then check the size.
int n_landmarks = in_landmarks.landmark_size();
if (!visibility_filters_.empty()) {
RET_CHECK_EQ(visibility_filters_.size(), n_landmarks);
} else {
visibility_filters_.resize(n_landmarks, LowPassFilter(alpha_));
}
// Filter visibilities.
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
const auto& in_landmark = in_landmarks.landmark(i);
auto* out_landmark = out_landmarks->add_landmark();
*out_landmark = in_landmark;
out_landmark->set_visibility(
visibility_filters_[i].Apply(in_landmark.visibility()));
}
return absl::OkStatus();
}
float alpha_;
std::vector<LowPassFilter> visibility_filters_;
};
} // namespace
// A calculator to smooth landmark visibilities over time.
//
// Exactly one landmarks input stream is expected. Output stream type should be
// the same as the input one.
//
// Inputs:
// LANDMARKS (optional): A LandmarkList of landmarks you want to smooth.
// NORM_LANDMARKS (optional): A NormalizedLandmarkList of landmarks you want
// to smooth.
//
// Outputs:
// FILTERED_LANDMARKS (optional): A LandmarkList of smoothed landmarks.
// NORM_FILTERED_LANDMARKS (optional): A NormalizedLandmarkList of smoothed
// landmarks.
//
// Example config:
// node {
// calculator: "VisibilitySmoothingCalculator"
// input_stream: "NORM_LANDMARKS:pose_landmarks"
// output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered"
// options: {
// [mediapipe.VisibilitySmoothingCalculatorOptions.ext] {
// low_pass_filter: {
// alpha: 0.1
// }
// }
// }
// }
//
class VisibilitySmoothingCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<VisibilityFilter> visibility_filter_;
};
REGISTER_CALCULATOR(VisibilitySmoothingCalculator);
absl::Status VisibilitySmoothingCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag) ^
cc->Inputs().HasTag(kLandmarksTag))
<< "Exactly one landmarks input stream is expected";
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>();
RET_CHECK(cc->Outputs().HasTag(kNormalizedFilteredLandmarksTag))
<< "Landmarks output stream should of the same type as input one";
cc->Outputs()
.Tag(kNormalizedFilteredLandmarksTag)
.Set<NormalizedLandmarkList>();
} else {
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
RET_CHECK(cc->Outputs().HasTag(kFilteredLandmarksTag))
<< "Landmarks output stream should of the same type as input one";
cc->Outputs().Tag(kFilteredLandmarksTag).Set<LandmarkList>();
}
return absl::OkStatus();
}
absl::Status VisibilitySmoothingCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
// Pick visibility filter.
const auto& options = cc->Options<VisibilitySmoothingCalculatorOptions>();
if (options.has_no_filter()) {
visibility_filter_ = absl::make_unique<NoFilter>();
} else if (options.has_low_pass_filter()) {
visibility_filter_ = absl::make_unique<LowPassVisibilityFilter>(
options.low_pass_filter().alpha());
} else {
RET_CHECK_FAIL()
<< "Visibility filter is either not specified or not supported";
}
return absl::OkStatus();
}
absl::Status VisibilitySmoothingCalculator::Process(CalculatorContext* cc) {
// Check that landmarks are not empty and reset the filter if so.
// Don't emit an empty packet for this timestamp.
if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) &&
cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) ||
(cc->Inputs().HasTag(kLandmarksTag) &&
cc->Inputs().Tag(kLandmarksTag).IsEmpty())) {
MP_RETURN_IF_ERROR(visibility_filter_->Reset());
return absl::OkStatus();
}
const auto& timestamp =
absl::Microseconds(cc->InputTimestamp().Microseconds());
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
const auto& in_landmarks =
cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
auto out_landmarks = absl::make_unique<NormalizedLandmarkList>();
MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp,
out_landmarks.get()));
cc->Outputs()
.Tag(kNormalizedFilteredLandmarksTag)
.Add(out_landmarks.release(), cc->InputTimestamp());
} else {
const auto& in_landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
auto out_landmarks = absl::make_unique<LandmarkList>();
MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp,
out_landmarks.get()));
cc->Outputs()
.Tag(kFilteredLandmarksTag)
.Add(out_landmarks.release(), cc->InputTimestamp());
}
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,40 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator_options.proto";
message VisibilitySmoothingCalculatorOptions {
extend CalculatorOptions {
optional VisibilitySmoothingCalculatorOptions ext = 360207350;
}
// Default behaviour and fast way to disable smoothing.
message NoFilter {}
message LowPassFilter {
// Coefficient applied to a new value, whilte `1 - alpha` is applied to a
// stored value. Should be in [0, 1] range. The smaller the value - the
// smoother result and the bigger lag.
optional float alpha = 1 [default = 0.1];
}
oneof filter_options {
NoFilter no_filter = 1;
LowPassFilter low_pass_filter = 2;
}
}

View File

@ -0,0 +1,108 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace {
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kRectTag[] = "NORM_RECT";
} // namespace
// Projects world landmarks from the rectangle to original coordinates.
//
// World landmarks are predicted in meters rather than in pixels of the image
// and have origin in the middle of the hips rather than in the corner of the
// pose image (cropped with given rectangle). Thus only rotation (but not scale
// and translation) is applied to the landmarks to transform them back to
// original coordinates.
//
// Input:
// LANDMARKS: A LandmarkList representing world landmarks in the rectangle.
// NORM_RECT: An NormalizedRect representing a normalized rectangle in image
// coordinates.
//
// Output:
// LANDMARKS: A LandmarkList representing world landmarks projected (rotated
// but not scaled or translated) from the rectangle to original
// coordinates.
//
// Usage example:
// node {
// calculator: "WorldLandmarkProjectionCalculator"
// input_stream: "LANDMARKS:landmarks"
// input_stream: "NORM_RECT:rect"
// output_stream: "LANDMARKS:projected_landmarks"
// }
//
class WorldLandmarkProjectionCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
cc->Outputs().Tag(kLandmarksTag).Set<LandmarkList>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
// Check that landmarks and rect are not empty.
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty() ||
cc->Inputs().Tag(kRectTag).IsEmpty()) {
return absl::OkStatus();
}
const auto& in_landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
const auto& in_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
auto out_landmarks = absl::make_unique<LandmarkList>();
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
const auto& in_landmark = in_landmarks.landmark(i);
Landmark* out_landmark = out_landmarks->add_landmark();
*out_landmark = in_landmark;
const float angle = in_rect.rotation();
out_landmark->set_x(std::cos(angle) * in_landmark.x() -
std::sin(angle) * in_landmark.y());
out_landmark->set_y(std::sin(angle) * in_landmark.x() +
std::cos(angle) * in_landmark.y());
}
cc->Outputs()
.Tag(kLandmarksTag)
.Add(out_landmarks.release(), cc->InputTimestamp());
return absl::OkStatus();
}
};
REGISTER_CALCULATOR(WorldLandmarkProjectionCalculator);
} // namespace mediapipe

View File

@ -426,6 +426,7 @@ cc_test(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
], ],
) )
@ -450,6 +451,7 @@ cc_test(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
], ],
) )
@ -534,6 +536,7 @@ cc_test(
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
"//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto",
"//mediapipe/util/tracking:tracking_cc_proto", "//mediapipe/util/tracking:tracking_cc_proto",
"@com_google_absl//absl/flags:flag",
], ],
) )

View File

@ -27,13 +27,14 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
], ],
) )

View File

@ -62,7 +62,7 @@ COPY . /mediapipe/
# Install bazel # Install bazel
# Please match the current MediaPipe Bazel requirements according to docs. # Please match the current MediaPipe Bazel requirements according to docs.
ARG BAZEL_VERSION=3.4.1 ARG BAZEL_VERSION=3.7.2
RUN mkdir /bazel && \ RUN mkdir /bazel && \
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \

View File

@ -15,10 +15,11 @@
// An example of sending OpenCV webcam frames into a MediaPipe graph. // An example of sending OpenCV webcam frames into a MediaPipe graph.
#include <cstdlib> #include <cstdlib>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
@ -30,13 +31,12 @@ constexpr char kInputStream[] = "input_video";
constexpr char kOutputStream[] = "output_video"; constexpr char kOutputStream[] = "output_video";
constexpr char kWindowName[] = "MediaPipe"; constexpr char kWindowName[] = "MediaPipe";
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_video_path, "", ABSL_FLAG(std::string, input_video_path, "",
"Full path of video to load. " "Full path of video to load. "
"If not provided, attempt to use a webcam."); "If not provided, attempt to use a webcam.");
DEFINE_string(output_video_path, "", ABSL_FLAG(std::string, output_video_path, "",
"Full path of where to save result (.mp4 only). " "Full path of where to save result (.mp4 only). "
"If not provided, show result in a window."); "If not provided, show result in a window.");
@ -143,7 +143,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -23,13 +23,14 @@ cc_library(
srcs = ["simple_run_graph_main.cc"], srcs = ["simple_run_graph_main.cc"],
deps = [ deps = [
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -41,13 +42,14 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
], ],
) )
@ -62,7 +64,6 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
@ -72,5 +73,7 @@ cc_library(
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:gpu_shared_data_internal",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
], ],
) )

View File

@ -54,18 +54,27 @@ mediapipe_cc_proto_library(
deps = [":border_detection_calculator_proto"], deps = [":border_detection_calculator_proto"],
) )
cc_library(
name = "content_zooming_calculator_state",
hdrs = ["content_zooming_calculator_state.h"],
deps = [
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:rect_cc_proto",
],
)
cc_library( cc_library(
name = "content_zooming_calculator", name = "content_zooming_calculator",
srcs = ["content_zooming_calculator.cc"], srcs = ["content_zooming_calculator.cc"],
deps = [ deps = [
":content_zooming_calculator_cc_proto", ":content_zooming_calculator_cc_proto",
":content_zooming_calculator_state",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
], ],
@ -88,7 +97,9 @@ mediapipe_cc_proto_library(
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto", "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
], ],
visibility = ["//mediapipe/examples:__subpackages__"], visibility = [
"//mediapipe/examples:__subpackages__",
],
deps = [ deps = [
":content_zooming_calculator_proto", ":content_zooming_calculator_proto",
], ],
@ -127,6 +138,7 @@ cc_test(
deps = [ deps = [
":content_zooming_calculator", ":content_zooming_calculator",
":content_zooming_calculator_cc_proto", ":content_zooming_calculator_cc_proto",
":content_zooming_calculator_state",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver", "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
@ -368,7 +380,6 @@ cc_test(
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
@ -376,6 +387,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -17,12 +17,11 @@
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
@ -33,12 +32,18 @@ constexpr char kSalientRegions[] = "SALIENT_REGIONS";
constexpr char kDetections[] = "DETECTIONS"; constexpr char kDetections[] = "DETECTIONS";
constexpr char kDetectedBorders[] = "BORDERS"; constexpr char kDetectedBorders[] = "BORDERS";
constexpr char kCropRect[] = "CROP_RECT"; constexpr char kCropRect[] = "CROP_RECT";
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
// Field-of-view (degrees) of the camera's x-axis (width). // Field-of-view (degrees) of the camera's x-axis (width).
// TODO: Parameterize FOV based on camera specs. // TODO: Parameterize FOV based on camera specs.
constexpr float kFieldOfView = 60; constexpr float kFieldOfView = 60;
// A pointer to a ContentZoomingCalculatorStateCacheType in a side packet.
// Used to save state on Close and load state on Open in a new graph.
// Can be used to preserve state between graphs.
constexpr char kStateCache[] = "STATE_CACHE";
namespace mediapipe { namespace mediapipe {
namespace autoflip { namespace autoflip {
using StateCacheType = ContentZoomingCalculatorStateCacheType;
// Content zooming calculator zooms in on content when a detection has // Content zooming calculator zooms in on content when a detection has
// "only_required" set true or any raw detection input. It does this by // "only_required" set true or any raw detection input. It does this by
@ -49,8 +54,7 @@ namespace autoflip {
// include mobile makeover and autofliplive face reframing. // include mobile makeover and autofliplive face reframing.
class ContentZoomingCalculator : public CalculatorBase { class ContentZoomingCalculator : public CalculatorBase {
public: public:
ContentZoomingCalculator() ContentZoomingCalculator() : initialized_(false) {}
: initialized_(false), last_only_required_detection_(0) {}
~ContentZoomingCalculator() override {} ~ContentZoomingCalculator() override {}
ContentZoomingCalculator(const ContentZoomingCalculator&) = delete; ContentZoomingCalculator(const ContentZoomingCalculator&) = delete;
ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete; ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete;
@ -58,8 +62,25 @@ class ContentZoomingCalculator : public CalculatorBase {
static absl::Status GetContract(mediapipe::CalculatorContract* cc); static absl::Status GetContract(mediapipe::CalculatorContract* cc);
absl::Status Open(mediapipe::CalculatorContext* cc) override; absl::Status Open(mediapipe::CalculatorContext* cc) override;
absl::Status Process(mediapipe::CalculatorContext* cc) override; absl::Status Process(mediapipe::CalculatorContext* cc) override;
absl::Status Close(mediapipe::CalculatorContext* cc) override;
private: private:
// Tries to load state from a state-cache, if provided. Fallsback to
// initializing state if no cache or no value in the cache are available.
absl::Status MaybeLoadState(mediapipe::CalculatorContext* cc, int frame_width,
int frame_height);
// Saves state to a state-cache, if provided.
absl::Status SaveState(mediapipe::CalculatorContext* cc) const;
// Initializes the calculator for the given frame size, creating path solvers
// and resetting history like last measured values.
absl::Status InitializeState(int frame_width, int frame_height);
// Adjusts state to work with an updated frame size.
absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
// Returns true if we are zooming to the initial rect.
bool IsZoomingToInitialRect(const Timestamp& timestamp) const;
// Builds the output rectangle when zooming to the initial rect.
absl::StatusOr<mediapipe::Rect> GetInitialZoomingRect(
int frame_width, int frame_height, const Timestamp& timestamp) const;
// Converts bounds to tilt offset, pan offset and height. // Converts bounds to tilt offset, pan offset and height.
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
float ymax, int* tilt_offset, float ymax, int* tilt_offset,
@ -76,6 +97,10 @@ class ContentZoomingCalculator : public CalculatorBase {
std::unique_ptr<KinematicPathSolver> path_solver_tilt_; std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
// Are parameters initialized. // Are parameters initialized.
bool initialized_; bool initialized_;
// Stores the time of the first crop rectangle.
Timestamp first_rect_timestamp_;
// Stores the first crop rectangle.
mediapipe::NormalizedRect first_rect_;
// Stores the time of the last "only_required" input. // Stores the time of the last "only_required" input.
int64 last_only_required_detection_; int64 last_only_required_detection_;
// Rect values of last message with detection(s). // Rect values of last message with detection(s).
@ -116,6 +141,12 @@ absl::Status ContentZoomingCalculator::GetContract(
if (cc->Outputs().HasTag(kCropRect)) { if (cc->Outputs().HasTag(kCropRect)) {
cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>(); cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
} }
if (cc->Outputs().HasTag(kFirstCropRect)) {
cc->Outputs().Tag(kFirstCropRect).Set<mediapipe::NormalizedRect>();
}
if (cc->InputSidePackets().HasTag(kStateCache)) {
cc->InputSidePackets().Tag(kStateCache).Set<StateCacheType*>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -135,6 +166,13 @@ absl::Status ContentZoomingCalculator::Open(mediapipe::CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ContentZoomingCalculator::Close(mediapipe::CalculatorContext* cc) {
if (initialized_) {
MP_RETURN_IF_ERROR(SaveState(cc));
}
return absl::OkStatus();
}
absl::Status ContentZoomingCalculator::ConvertToPanTiltZoom( absl::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
float xmin, float xmax, float ymin, float ymax, int* tilt_offset, float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
int* pan_offset, int* height) { int* pan_offset, int* height) {
@ -275,18 +313,64 @@ absl::Status ContentZoomingCalculator::UpdateAspectAndMax() {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ContentZoomingCalculator::Process( absl::Status ContentZoomingCalculator::MaybeLoadState(
mediapipe::CalculatorContext* cc) { mediapipe::CalculatorContext* cc, int frame_width, int frame_height) {
// For async subgraph support, return on empty video size packets. const auto* state_cache =
if (cc->Inputs().HasTag(kVideoSize) && cc->InputSidePackets().HasTag(kStateCache)
cc->Inputs().Tag(kVideoSize).IsEmpty()) { ? cc->InputSidePackets().Tag(kStateCache).Get<StateCacheType*>()
: nullptr;
if (!state_cache || !state_cache->has_value()) {
return InitializeState(frame_width, frame_height);
}
const ContentZoomingCalculatorState& state = state_cache->value();
frame_width_ = state.frame_width;
frame_height_ = state.frame_height;
path_solver_pan_ =
std::make_unique<KinematicPathSolver>(state.path_solver_pan);
path_solver_tilt_ =
std::make_unique<KinematicPathSolver>(state.path_solver_tilt);
path_solver_zoom_ =
std::make_unique<KinematicPathSolver>(state.path_solver_zoom);
first_rect_timestamp_ = state.first_rect_timestamp;
first_rect_ = state.first_rect;
last_only_required_detection_ = state.last_only_required_detection;
last_measured_height_ = state.last_measured_height;
last_measured_x_offset_ = state.last_measured_x_offset;
last_measured_y_offset_ = state.last_measured_y_offset;
MP_RETURN_IF_ERROR(UpdateAspectAndMax());
return UpdateForResolutionChange(frame_width, frame_height);
}
absl::Status ContentZoomingCalculator::SaveState(
mediapipe::CalculatorContext* cc) const {
auto* state_cache =
cc->InputSidePackets().HasTag(kStateCache)
? cc->InputSidePackets().Tag(kStateCache).Get<StateCacheType*>()
: nullptr;
if (!state_cache) {
return absl::OkStatus(); return absl::OkStatus();
} }
int frame_width, frame_height;
MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height));
// Init on first call. *state_cache = ContentZoomingCalculatorState{
if (!initialized_) { .frame_height = frame_height_,
.frame_width = frame_width_,
.path_solver_zoom = *path_solver_zoom_,
.path_solver_pan = *path_solver_pan_,
.path_solver_tilt = *path_solver_tilt_,
.first_rect_timestamp = first_rect_timestamp_,
.first_rect = first_rect_,
.last_only_required_detection = last_only_required_detection_,
.last_measured_height = last_measured_height_,
.last_measured_x_offset = last_measured_x_offset_,
.last_measured_y_offset = last_measured_y_offset_,
};
return absl::OkStatus();
}
absl::Status ContentZoomingCalculator::InitializeState(int frame_width,
int frame_height) {
frame_width_ = frame_width; frame_width_ = frame_width;
frame_height_ = frame_height; frame_height_ = frame_height;
path_solver_pan_ = std::make_unique<KinematicPathSolver>( path_solver_pan_ = std::make_unique<KinematicPathSolver>(
@ -302,12 +386,16 @@ absl::Status ContentZoomingCalculator::Process(
options_.kinematic_options_zoom(), min_zoom_size, options_.kinematic_options_zoom(), min_zoom_size,
max_frame_value_ * frame_height_, max_frame_value_ * frame_height_,
static_cast<float>(frame_height_) / kFieldOfView); static_cast<float>(frame_height_) / kFieldOfView);
first_rect_timestamp_ = Timestamp::Unset();
last_only_required_detection_ = 0;
last_measured_height_ = max_frame_value_ * frame_height_; last_measured_height_ = max_frame_value_ * frame_height_;
last_measured_x_offset_ = target_aspect_ * frame_width_; last_measured_x_offset_ = target_aspect_ * frame_width_;
last_measured_y_offset_ = frame_width_ / 2; last_measured_y_offset_ = frame_width_ / 2;
initialized_ = true; return absl::OkStatus();
} }
absl::Status ContentZoomingCalculator::UpdateForResolutionChange(
int frame_width, int frame_height) {
// Update state for change in input resolution. // Update state for change in input resolution.
if (frame_width_ != frame_width || frame_height_ != frame_height) { if (frame_width_ != frame_width || frame_height_ != frame_height) {
double width_scale = frame_width / static_cast<double>(frame_width_); double width_scale = frame_width / static_cast<double>(frame_width_);
@ -328,6 +416,74 @@ absl::Status ContentZoomingCalculator::Process(
MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree( MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree(
static_cast<float>(frame_height_) / kFieldOfView)); static_cast<float>(frame_height_) / kFieldOfView));
} }
return absl::OkStatus();
}
bool ContentZoomingCalculator::IsZoomingToInitialRect(
const Timestamp& timestamp) const {
if (options_.us_to_first_rect() == 0 ||
first_rect_timestamp_ == Timestamp::Unset()) {
return false;
}
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
return (0 <= delta_us && delta_us <= options_.us_to_first_rect());
}
namespace {
double easeInQuad(double t) { return t * t; }
double easeOutQuad(double t) { return -1 * t * (t - 2); }
double easeInOutQuad(double t) {
if (t < 0.5) {
return easeInQuad(t * 2) * 0.5;
} else {
return easeOutQuad(t * 2 - 1) * 0.5 + 0.5;
}
}
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
} // namespace
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetInitialZoomingRect(
int frame_width, int frame_height, const Timestamp& timestamp) const {
RET_CHECK(IsZoomingToInitialRect(timestamp))
<< "Must only be called if zooming to initial rect.";
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
const int64 delay = options_.us_to_first_rect_delay();
const double interpolation = easeInOutQuad(std::max(
0.0, (delta_us - delay) /
static_cast<double>(options_.us_to_first_rect() - delay)));
const double x_center = lerp(0.5, first_rect_.x_center(), interpolation);
const double y_center = lerp(0.5, first_rect_.y_center(), interpolation);
const double width = lerp(1.0, first_rect_.width(), interpolation);
const double height = lerp(1.0, first_rect_.height(), interpolation);
mediapipe::Rect gpu_rect;
gpu_rect.set_x_center(x_center * frame_width);
gpu_rect.set_width(width * frame_width);
gpu_rect.set_y_center(y_center * frame_height);
gpu_rect.set_height(height * frame_height);
return gpu_rect;
}
absl::Status ContentZoomingCalculator::Process(
mediapipe::CalculatorContext* cc) {
// For async subgraph support, return on empty video size packets.
if (cc->Inputs().HasTag(kVideoSize) &&
cc->Inputs().Tag(kVideoSize).IsEmpty()) {
return absl::OkStatus();
}
int frame_width, frame_height;
MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height));
// Init on first call or re-init always if configured to be stateless.
if (!initialized_) {
MP_RETURN_IF_ERROR(MaybeLoadState(cc, frame_width, frame_height));
initialized_ = !options_.is_stateless();
} else {
MP_RETURN_IF_ERROR(UpdateForResolutionChange(frame_width, frame_height));
}
bool only_required_found = false; bool only_required_found = false;
@ -348,6 +504,10 @@ absl::Status ContentZoomingCalculator::Process(
if (cc->Inputs().HasTag(kDetections)) { if (cc->Inputs().HasTag(kDetections)) {
if (cc->Inputs().Tag(kDetections).IsEmpty()) { if (cc->Inputs().Tag(kDetections).IsEmpty()) {
if (last_only_required_detection_ == 0) {
// If no detections are available and we never had any,
// simply return the full-image rectangle as crop-rect.
if (cc->Outputs().HasTag(kCropRect)) {
auto default_rect = absl::make_unique<mediapipe::Rect>(); auto default_rect = absl::make_unique<mediapipe::Rect>();
default_rect->set_x_center(frame_width_ / 2); default_rect->set_x_center(frame_width_ / 2);
default_rect->set_y_center(frame_height_ / 2); default_rect->set_y_center(frame_height_ / 2);
@ -355,10 +515,20 @@ absl::Status ContentZoomingCalculator::Process(
default_rect->set_height(frame_height_); default_rect->set_height(frame_height_);
cc->Outputs().Tag(kCropRect).Add(default_rect.release(), cc->Outputs().Tag(kCropRect).Add(default_rect.release(),
Timestamp(cc->InputTimestamp())); Timestamp(cc->InputTimestamp()));
}
// Also provide a first crop rect: in this case a zero-sized one.
if (cc->Outputs().HasTag(kFirstCropRect)) {
cc->Outputs()
.Tag(kFirstCropRect)
.Add(new mediapipe::NormalizedRect(),
Timestamp(cc->InputTimestamp()));
}
return absl::OkStatus(); return absl::OkStatus();
} }
auto raw_detections = } else {
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>(); auto raw_detections = cc->Inputs()
.Tag(kDetections)
.Get<std::vector<mediapipe::Detection>>();
for (const auto& detection : raw_detections) { for (const auto& detection : raw_detections) {
only_required_found = true; only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges( MP_RETURN_IF_ERROR(UpdateRanges(
@ -366,13 +536,20 @@ absl::Status ContentZoomingCalculator::Process(
options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax)); options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax));
} }
} }
}
bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
// Convert bounds to tilt/zoom and in pixel coordinates.
int offset_y, height, offset_x; int offset_y, height, offset_x;
if (zooming_to_initial_rect) {
// If we are zooming to the first rect, ignore any new incoming detections.
height = last_measured_height_;
offset_x = last_measured_x_offset_;
offset_y = last_measured_y_offset_;
} else if (only_required_found) {
// Convert bounds to tilt/zoom and in pixel coordinates.
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y, MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
&offset_x, &height)); &offset_x, &height));
if (only_required_found) {
// A only required detection was found. // A only required detection was found.
last_only_required_detection_ = cc->InputTimestamp().Microseconds(); last_only_required_detection_ = cc->InputTimestamp().Microseconds();
last_measured_height_ = height; last_measured_height_ = height;
@ -383,7 +560,9 @@ absl::Status ContentZoomingCalculator::Process(
options_.us_before_zoomout()) { options_.us_before_zoomout()) {
// No only_require detections found within salient regions packets // No only_require detections found within salient regions packets
// arriving since us_before_zoomout duration. // arriving since us_before_zoomout duration.
height = max_frame_value_ * frame_height_; height = max_frame_value_ * frame_height_ +
(options_.kinematic_options_zoom().min_motion_to_reframe() *
(static_cast<float>(frame_height_) / kFieldOfView));
offset_x = (target_aspect_ * height) / 2; offset_x = (target_aspect_ * height) / 2;
offset_y = frame_height_ / 2; offset_y = frame_height_ / 2;
} else { } else {
@ -463,17 +642,44 @@ absl::Status ContentZoomingCalculator::Process(
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); .AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
} }
if (first_rect_timestamp_ == Timestamp::Unset() &&
options_.us_to_first_rect() != 0) {
first_rect_timestamp_ = cc->InputTimestamp();
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
first_rect_.set_width(path_height * target_aspect_ /
static_cast<float>(frame_width_));
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
first_rect_.set_height(path_height / static_cast<float>(frame_height_));
// After setting the first rectangle, check whether we should zoom to it.
zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
}
// Transmit downstream to glcroppingcalculator. // Transmit downstream to glcroppingcalculator.
if (cc->Outputs().HasTag(kCropRect)) { if (cc->Outputs().HasTag(kCropRect)) {
auto gpu_rect = absl::make_unique<mediapipe::Rect>(); std::unique_ptr<mediapipe::Rect> gpu_rect;
if (zooming_to_initial_rect) {
auto rect = GetInitialZoomingRect(frame_width, frame_height,
cc->InputTimestamp());
MP_RETURN_IF_ERROR(rect.status());
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
} else {
gpu_rect = absl::make_unique<mediapipe::Rect>();
gpu_rect->set_x_center(path_offset_x); gpu_rect->set_x_center(path_offset_x);
gpu_rect->set_width(path_height * target_aspect_); gpu_rect->set_width(path_height * target_aspect_);
gpu_rect->set_y_center(path_offset_y); gpu_rect->set_y_center(path_offset_y);
gpu_rect->set_height(path_height); gpu_rect->set_height(path_height);
}
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(), cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
Timestamp(cc->InputTimestamp())); Timestamp(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag(kFirstCropRect)) {
cc->Outputs()
.Tag(kFirstCropRect)
.Add(new mediapipe::NormalizedRect(first_rect_),
Timestamp(cc->InputTimestamp()));
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -19,7 +19,7 @@ package mediapipe.autoflip;
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
// NextTag: 14 // NextTag: 17
message ContentZoomingCalculatorOptions { message ContentZoomingCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ContentZoomingCalculatorOptions ext = 313091992; optional ContentZoomingCalculatorOptions ext = 313091992;
@ -55,6 +55,16 @@ message ContentZoomingCalculatorOptions {
// Defines the smallest value in degrees the camera is permitted to zoom. // Defines the smallest value in degrees the camera is permitted to zoom.
optional float max_zoom_value_deg = 13 [default = 35]; optional float max_zoom_value_deg = 13 [default = 35];
// Whether to keep state between frames or to compute the final crop rect.
optional bool is_stateless = 14 [default = false];
// Duration (in MicroSeconds) for moving to the first crop rect.
optional int64 us_to_first_rect = 15 [default = 0];
// Duration (in MicroSeconds) to delay moving to the first crop rect.
// Used only if us_to_first_rect is set and is interpreted as part of the
// us_to_first_rect time budget.
optional int64 us_to_first_rect_delay = 16 [default = 0];
// Deprecated parameters // Deprecated parameters
optional KinematicOptions kinematic_options = 2 [deprecated = true]; optional KinematicOptions kinematic_options = 2 [deprecated = true];
optional int64 min_motion_to_reframe = 4 [deprecated = true]; optional int64 min_motion_to_reframe = 4 [deprecated = true];

View File

@ -0,0 +1,38 @@
#ifndef MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_
#define MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_
#include <optional>
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace autoflip {
struct ContentZoomingCalculatorState {
int frame_height = -1;
int frame_width = -1;
// Path solver used to smooth top/bottom border crop values.
KinematicPathSolver path_solver_zoom;
KinematicPathSolver path_solver_pan;
KinematicPathSolver path_solver_tilt;
// Stores the time of the first crop rectangle.
Timestamp first_rect_timestamp;
// Stores the first crop rectangle.
mediapipe::NormalizedRect first_rect;
// Stores the time of the last "only_required" input.
int64 last_only_required_detection = 0;
// Rect values of last message with detection(s).
int last_measured_height = 0;
int last_measured_x_offset = 0;
int last_measured_y_offset = 0;
};
using ContentZoomingCalculatorStateCacheType =
std::optional<ContentZoomingCalculatorState>;
} // namespace autoflip
} // namespace mediapipe
#endif // MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_

View File

@ -16,6 +16,7 @@
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h"
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h" #include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
@ -109,6 +110,7 @@ const char kConfigD[] = R"(
input_stream: "VIDEO_SIZE:size" input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detections" input_stream: "DETECTIONS:detections"
output_stream: "CROP_RECT:rect" output_stream: "CROP_RECT:rect"
output_stream: "FIRST_CROP_RECT:first_rect"
options: { options: {
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
max_zoom_value_deg: 0 max_zoom_value_deg: 0
@ -147,19 +149,24 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
const int width, const int height, const int width, const int height,
CalculatorRunner* runner) { CalculatorRunner* runner) {
auto detections = std::make_unique<std::vector<mediapipe::Detection>>(); auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
if (position.width > 0 && position.height > 0) {
mediapipe::Detection detection; mediapipe::Detection detection;
detection.mutable_location_data()->set_format( detection.mutable_location_data()->set_format(
mediapipe::LocationData::RELATIVE_BOUNDING_BOX); mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
detection.mutable_location_data() detection.mutable_location_data()
->mutable_relative_bounding_box() ->mutable_relative_bounding_box()
->set_height(position.height); ->set_height(position.height);
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width( detection.mutable_location_data()
position.width); ->mutable_relative_bounding_box()
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin( ->set_width(position.width);
position.x); detection.mutable_location_data()
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin( ->mutable_relative_bounding_box()
position.y); ->set_xmin(position.x);
detection.mutable_location_data()
->mutable_relative_bounding_box()
->set_ymin(position.y);
detections->push_back(detection); detections->push_back(detection);
}
runner->MutableInputs() runner->MutableInputs()
->Tag("DETECTIONS") ->Tag("DETECTIONS")
.packets.push_back(Adopt(detections.release()).At(Timestamp(time))); .packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
@ -185,7 +192,6 @@ void CheckCropRect(const int x_center, const int y_center, const int width,
EXPECT_EQ(rect.width(), width); EXPECT_EQ(rect.width(), width);
EXPECT_EQ(rect.height(), height); EXPECT_EQ(rect.height(), height);
} }
TEST(ContentZoomingCalculatorTest, ZoomTest) { TEST(ContentZoomingCalculatorTest, ZoomTest) {
auto runner = ::absl::make_unique<CalculatorRunner>( auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA)); ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
@ -244,6 +250,46 @@ TEST(ContentZoomingCalculatorTest, PanConfig) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, PanConfigWithCache) {
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
config.add_input_side_packet("STATE_CACHE:state_cache");
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_pan()->set_update_rate_seconds(2);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(50.0);
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(483, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
// Now repeat the last frame for a new runner without the cache to see a reset
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr);
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 2000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost.
runner->Outputs().Tag("CROP_RECT").packets);
}
}
TEST(ContentZoomingCalculatorTest, TiltConfig) { TEST(ContentZoomingCalculatorTest, TiltConfig) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension( auto* options = config.mutable_options()->MutableExtension(
@ -280,6 +326,46 @@ TEST(ContentZoomingCalculatorTest, ZoomConfig) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, ZoomConfigWithCache) {
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
config.add_input_side_packet("STATE_CACHE:state_cache");
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(50.0);
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2);
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 111, 111, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(450, 550, 139, 139, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
// Now repeat the last frame for a new runner without the cache to see a reset
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr);
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 2000000, runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost.
runner->Outputs().Tag("CROP_RECT").packets);
}
}
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) { TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
auto runner = ::absl::make_unique<CalculatorRunner>( auto runner = ::absl::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB)); ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
@ -509,6 +595,32 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeStationary) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, ResolutionChangeStationaryWithCache) {
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
config.add_input_side_packet("STATE_CACHE:state_cache");
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 222, 222, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1, 500, 500,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500 * 0.5, 500 * 0.5, 222 * 0.5, 222 * 0.5, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
}
TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) { TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto runner = ::absl::make_unique<CalculatorRunner>(config); auto runner = ::absl::make_unique<CalculatorRunner>(config);
@ -527,6 +639,37 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, ResolutionChangeZoomingWithCache) {
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
config.add_input_side_packet("STATE_CACHE:state_cache");
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.1, .1, .8, .8), 0, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 888, 888, 0,
runner->Outputs().Tag("CROP_RECT").packets);
}
// The second runner should just resume based on state from the first runner.
{
auto runner = ::absl::make_unique<CalculatorRunner>(config);
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 2000000, 500, 500,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 588, 588, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500 * 0.5, 500 * 0.5, 288 * 0.5, 288 * 0.5, 1,
runner->Outputs().Tag("CROP_RECT").packets);
}
}
TEST(ContentZoomingCalculatorTest, MaxZoomValue) { TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension( auto* options = config.mutable_options()->MutableExtension(
@ -540,6 +683,108 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
CheckCropRect(500, 500, 916, 916, 0, CheckCropRect(500, 500, 916, 916, 0,
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_scale_factor(1.0);
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetectionFrameSize(cv::Rect_<float>(.025, .025, .95, .95), 0, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(0, 0, -1, -1), 1000000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(0, 0, -1, -1), 2000000, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
// 55/60 * 1000 = 916
CheckCropRect(500, 500, 950, 950, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 1000, 1000, 2,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_us_to_first_rect(1000000);
options->set_us_to_first_rect_delay(500000);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 1000, 1000, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 1000, 1000, 1,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 470, 470, 2,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 222, 222, 3,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 222, 222, 4,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
runner->MutableInputs()
->Tag("VIDEO_SIZE")
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
MP_ASSERT_OK(runner->Run());
const std::vector<Packet>& output_packets =
runner->Outputs().Tag("FIRST_CROP_RECT").packets;
ASSERT_EQ(output_packets.size(), 1);
const auto& rect = output_packets[0].Get<mediapipe::NormalizedRect>();
EXPECT_EQ(rect.x_center(), 0);
EXPECT_EQ(rect.y_center(), 0);
EXPECT_EQ(rect.width(), 0);
EXPECT_EQ(rect.height(), 0);
}
TEST(ContentZoomingCalculatorTest, ProvidesConstantFirstRect) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_us_to_first_rect(500000);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 500000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
const std::vector<Packet>& output_packets =
runner->Outputs().Tag("FIRST_CROP_RECT").packets;
ASSERT_EQ(output_packets.size(), 4);
const auto& first_rect = output_packets[0].Get<mediapipe::NormalizedRect>();
EXPECT_NEAR(first_rect.x_center(), 0.5, 0.05);
EXPECT_NEAR(first_rect.y_center(), 0.5, 0.05);
EXPECT_NEAR(first_rect.width(), 0.222, 0.05);
EXPECT_NEAR(first_rect.height(), 0.222, 0.05);
for (int i = 1; i < 4; ++i) {
const auto& rect = output_packets[i].Get<mediapipe::NormalizedRect>();
EXPECT_EQ(first_rect.x_center(), rect.x_center());
EXPECT_EQ(first_rect.y_center(), rect.y_center());
EXPECT_EQ(first_rect.width(), rect.width());
EXPECT_EQ(first_rect.height(), rect.height());
}
}
} // namespace } // namespace
} // namespace autoflip } // namespace autoflip

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"

View File

@ -46,7 +46,9 @@ proto_library(
mediapipe_cc_proto_library( mediapipe_cc_proto_library(
name = "kinematic_path_solver_cc_proto", name = "kinematic_path_solver_cc_proto",
srcs = ["kinematic_path_solver.proto"], srcs = ["kinematic_path_solver.proto"],
visibility = ["//mediapipe/examples:__subpackages__"], visibility = [
"//mediapipe/examples:__subpackages__",
],
deps = [":kinematic_path_solver_proto"], deps = [":kinematic_path_solver_proto"],
) )
@ -96,11 +98,11 @@ cc_library(
deps = [ deps = [
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
], ],
) )
@ -249,10 +251,10 @@ cc_test(
":scene_camera_motion_analyzer", ":scene_camera_motion_analyzer",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -280,13 +282,13 @@ cc_test(
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -14,11 +14,11 @@
#include "mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h" #include "mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h"
#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -28,8 +28,9 @@
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
DEFINE_string(input_image, "", "The path to an input image."); ABSL_FLAG(std::string, input_image, "", "The path to an input image.");
DEFINE_string(output_folder, "", "The folder to output test result images."); ABSL_FLAG(std::string, output_folder, "",
"The folder to output test result images.");
namespace mediapipe { namespace mediapipe {
namespace autoflip { namespace autoflip {

View File

@ -19,12 +19,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/flags/flag.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h"
#include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h" #include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -202,7 +202,7 @@ absl::Status DrawFocusPointAndCropWindow(
const auto& point = focus_point_frames[i].point(j); const auto& point = focus_point_frames[i].point(j);
const int x = point.norm_point_x() * scene_frame.cols; const int x = point.norm_point_x() * scene_frame.cols;
const int y = point.norm_point_y() * scene_frame.rows; const int y = point.norm_point_y() * scene_frame.rows;
cv::circle(viz_mat, cv::Point(x, y), 3, kRed, CV_FILLED); cv::circle(viz_mat, cv::Point(x, y), 3, kRed, cv::FILLED);
center_x += x; center_x += x;
center_y += y; center_y += y;
} }

View File

@ -15,10 +15,11 @@
// An example of sending OpenCV webcam frames into a MediaPipe graph. // An example of sending OpenCV webcam frames into a MediaPipe graph.
#include <cstdlib> #include <cstdlib>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
@ -30,13 +31,12 @@ constexpr char kInputStream[] = "input_video";
constexpr char kOutputStream[] = "output_video"; constexpr char kOutputStream[] = "output_video";
constexpr char kWindowName[] = "MediaPipe"; constexpr char kWindowName[] = "MediaPipe";
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_video_path, "", ABSL_FLAG(std::string, input_video_path, "",
"Full path of video to load. " "Full path of video to load. "
"If not provided, attempt to use a webcam."); "If not provided, attempt to use a webcam.");
DEFINE_string(output_video_path, "", ABSL_FLAG(std::string, output_video_path, "",
"Full path of where to save result (.mp4 only). " "Full path of where to save result (.mp4 only). "
"If not provided, show result in a window."); "If not provided, show result in a window.");
@ -148,7 +148,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -16,10 +16,11 @@
// This example requires a linux computer and a GPU with EGL support drivers. // This example requires a linux computer and a GPU with EGL support drivers.
#include <cstdlib> #include <cstdlib>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
@ -34,13 +35,12 @@ constexpr char kInputStream[] = "input_video";
constexpr char kOutputStream[] = "output_video"; constexpr char kOutputStream[] = "output_video";
constexpr char kWindowName[] = "MediaPipe"; constexpr char kWindowName[] = "MediaPipe";
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_video_path, "", ABSL_FLAG(std::string, input_video_path, "",
"Full path of video to load. " "Full path of video to load. "
"If not provided, attempt to use a webcam."); "If not provided, attempt to use a webcam.");
DEFINE_string(output_video_path, "", ABSL_FLAG(std::string, output_video_path, "",
"Full path of where to save result (.mp4 only). " "Full path of where to save result (.mp4 only). "
"If not provided, show result in a window."); "If not provided, show result in a window.");
@ -191,7 +191,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -23,7 +23,6 @@ cc_binary(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
@ -31,6 +30,8 @@ cc_binary(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/graphs/iris_tracking:iris_depth_cpu_deps", "//mediapipe/graphs/iris_tracking:iris_depth_cpu_deps",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
], ],
) )

View File

@ -17,11 +17,12 @@
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
@ -38,10 +39,10 @@ constexpr char kCalculatorGraphConfigFile[] =
"mediapipe/graphs/iris_tracking/iris_depth_cpu.pbtxt"; "mediapipe/graphs/iris_tracking/iris_depth_cpu.pbtxt";
constexpr float kMicrosPerSecond = 1e6; constexpr float kMicrosPerSecond = 1e6;
DEFINE_string(input_image_path, "", ABSL_FLAG(std::string, input_image_path, "",
"Full path of image to load. " "Full path of image to load. "
"If not provided, nothing will run."); "If not provided, nothing will run.");
DEFINE_string(output_image_path, "", ABSL_FLAG(std::string, output_image_path, "",
"Full path of where to save image result (.jpg only). " "Full path of where to save image result (.jpg only). "
"If not provided, show result in a window."); "If not provided, show result in a window.");
@ -148,7 +149,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -21,11 +21,12 @@ cc_library(
srcs = ["run_graph_file_io_main.cc"], srcs = ["run_graph_file_io_main.cc"],
deps = [ deps = [
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -17,23 +17,23 @@
// to disk. // to disk.
#include <cstdlib> #include <cstdlib>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_side_packets, "", ABSL_FLAG(std::string, input_side_packets, "",
"Comma-separated list of key=value pairs specifying side packets " "Comma-separated list of key=value pairs specifying side packets "
"and corresponding file paths for the CalculatorGraph. The side " "and corresponding file paths for the CalculatorGraph. The side "
"packets are read from the files and fed to the graph as strings " "packets are read from the files and fed to the graph as strings "
"even if they represent doubles, floats, etc."); "even if they represent doubles, floats, etc.");
DEFINE_string(output_side_packets, "", ABSL_FLAG(std::string, output_side_packets, "",
"Comma-separated list of key=value pairs specifying the output " "Comma-separated list of key=value pairs specifying the output "
"side packets and paths to write to disk for the " "side packets and paths to write to disk for the "
"CalculatorGraph."); "CalculatorGraph.");
@ -85,7 +85,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -20,7 +20,7 @@ package(default_visibility = ["//mediapipe/examples:__subpackages__"])
# To run 3D object detection for shoes, # To run 3D object detection for shoes,
# bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \ # bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \
# --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \ # --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \
# --input_side_packets="input_video_path=<input_video_path>,box_landmark_model_path=mediapipe/models/object_detection_3d_sneakers.tflite,output_video_path=<output_video_path>,allowed_labels=Footwear" # --input_side_packets="input_video_path=<input_video_path>,box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_sneakers.tflite,output_video_path=<output_video_path>,allowed_labels=Footwear"
# To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly. # To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly.
# Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair # Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair
# Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera # Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera

View File

@ -20,11 +20,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
@ -32,29 +33,28 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_side_packets, "", ABSL_FLAG(std::string, input_side_packets, "",
"Comma-separated list of key=value pairs specifying side packets " "Comma-separated list of key=value pairs specifying side packets "
"for the CalculatorGraph. All values will be treated as the " "for the CalculatorGraph. All values will be treated as the "
"string type even if they represent doubles, floats, etc."); "string type even if they represent doubles, floats, etc.");
// Local file output flags. // Local file output flags.
// Output stream // Output stream
DEFINE_string(output_stream, "", ABSL_FLAG(std::string, output_stream, "",
"The output stream to output to the local file in csv format."); "The output stream to output to the local file in csv format.");
DEFINE_string(output_stream_file, "", ABSL_FLAG(std::string, output_stream_file, "",
"The name of the local file to output all packets sent to " "The name of the local file to output all packets sent to "
"the stream specified with --output_stream. "); "the stream specified with --output_stream. ");
DEFINE_bool(strip_timestamps, false, ABSL_FLAG(bool, strip_timestamps, false,
"If true, only the packet contents (without timestamps) will be " "If true, only the packet contents (without timestamps) will be "
"written into the local file."); "written into the local file.");
// Output side packets // Output side packets
DEFINE_string(output_side_packets, "", ABSL_FLAG(std::string, output_side_packets, "",
"A CSV of output side packets to output to local file."); "A CSV of output side packets to output to local file.");
DEFINE_string(output_side_packets_file, "", ABSL_FLAG(std::string, output_side_packets_file, "",
"The name of the local file to output all side packets specified " "The name of the local file to output all side packets specified "
"with --output_side_packets. "); "with --output_side_packets. ");
@ -143,7 +143,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -18,10 +18,11 @@ cc_binary(
name = "extract_yt8m_features", name = "extract_yt8m_features",
srcs = ["extract_yt8m_features.cc"], srcs = ["extract_yt8m_features.cc"],
deps = [ deps = [
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",

View File

@ -17,24 +17,24 @@
// to disk. // to disk.
#include <cstdlib> #include <cstdlib>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
DEFINE_string( ABSL_FLAG(std::string, calculator_graph_config_file, "",
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto."); "Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_side_packets, "", ABSL_FLAG(std::string, input_side_packets, "",
"Comma-separated list of key=value pairs specifying side packets " "Comma-separated list of key=value pairs specifying side packets "
"and corresponding file paths for the CalculatorGraph. The side " "and corresponding file paths for the CalculatorGraph. The side "
"packets are read from the files and fed to the graph as strings " "packets are read from the files and fed to the graph as strings "
"even if they represent doubles, floats, etc."); "even if they represent doubles, floats, etc.");
DEFINE_string(output_side_packets, "", ABSL_FLAG(std::string, output_side_packets, "",
"Comma-separated list of key=value pairs specifying the output " "Comma-separated list of key=value pairs specifying the output "
"side packets and paths to write to disk for the " "side packets and paths to write to disk for the "
"CalculatorGraph."); "CalculatorGraph.");
@ -126,7 +126,7 @@ absl::Status RunMPPGraph() {
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true); absl::ParseCommandLine(argc, argv);
absl::Status run_status = RunMPPGraph(); absl::Status run_status = RunMPPGraph();
if (!run_status.ok()) { if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message(); LOG(ERROR) << "Failed to run the graph: " << run_status.message();

View File

@ -23,7 +23,6 @@ package(default_visibility = ["//visibility:private"])
package_group( package_group(
name = "mediapipe_internal", name = "mediapipe_internal",
packages = [ packages = [
"//java/com/google/mediapipe/framework/...",
"//mediapipe/...", "//mediapipe/...",
], ],
) )
@ -78,21 +77,19 @@ mediapipe_proto_library(
mediapipe_proto_library( mediapipe_proto_library(
name = "mediapipe_options_proto", name = "mediapipe_options_proto",
srcs = ["mediapipe_options.proto"], srcs = ["mediapipe_options.proto"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = [":mediapipe_internal"],
) )
mediapipe_proto_library( mediapipe_proto_library(
name = "packet_factory_proto", name = "packet_factory_proto",
srcs = ["packet_factory.proto"], srcs = ["packet_factory.proto"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = [":mediapipe_internal"],
) )
mediapipe_proto_library( mediapipe_proto_library(
name = "packet_generator_proto", name = "packet_generator_proto",
srcs = ["packet_generator.proto"], srcs = ["packet_generator.proto"],
visibility = [ visibility = [":mediapipe_internal"],
"//mediapipe:__subpackages__",
],
) )
mediapipe_proto_library( mediapipe_proto_library(
@ -105,7 +102,7 @@ mediapipe_proto_library(
mediapipe_proto_library( mediapipe_proto_library(
name = "status_handler_proto", name = "status_handler_proto",
srcs = ["status_handler.proto"], srcs = ["status_handler.proto"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = [":mediapipe_internal"],
deps = ["//mediapipe/framework:mediapipe_options_proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"],
) )
@ -274,14 +271,17 @@ cc_library(
], ],
deps = [ deps = [
":calculator_base", ":calculator_base",
":calculator_node",
":counter_factory", ":counter_factory",
":delegating_executor", ":delegating_executor",
":mediapipe_profiling", ":mediapipe_profiling",
":executor", ":executor",
":graph_output_stream", ":graph_output_stream",
":graph_service",
":graph_service_manager",
":input_stream_manager", ":input_stream_manager",
":input_stream_shard", ":input_stream_shard",
":graph_service", ":output_side_packet_impl",
":output_stream", ":output_stream",
":output_stream_manager", ":output_stream_manager",
":output_stream_poller", ":output_stream_poller",
@ -303,29 +303,27 @@ cc_library(
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework:thread_pool_executor_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"//mediapipe/gpu:graph_support",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
":calculator_node",
":output_side_packet_impl",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/tool:fill_packet_set",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate",
"//mediapipe/framework/tool:validate_name",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location", "//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/profiler:graph_profiler",
"//mediapipe/framework/tool:fill_packet_set",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate",
"//mediapipe/framework/tool:validate_name",
"//mediapipe/gpu:graph_support",
"//mediapipe/util:cpu_util", "//mediapipe/util:cpu_util",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
@ -336,6 +334,28 @@ cc_library(
}), }),
) )
cc_library(
name = "graph_service_manager",
srcs = ["graph_service_manager.cc"],
hdrs = ["graph_service_manager.h"],
visibility = [":mediapipe_internal"],
deps = [
":graph_service",
"//mediapipe/framework:packet",
"@com_google_absl//absl/status",
],
)
cc_test(
name = "graph_service_manager_test",
srcs = ["graph_service_manager_test.cc"],
deps = [
":graph_service_manager",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:gtest_main",
],
)
cc_library( cc_library(
name = "calculator_node", name = "calculator_node",
srcs = ["calculator_node.cc"], srcs = ["calculator_node.cc"],
@ -425,6 +445,7 @@ cc_library(
":counter", ":counter",
":counter_factory", ":counter_factory",
":graph_service", ":graph_service",
":graph_service_manager",
":input_stream", ":input_stream",
":output_stream", ":output_stream",
":packet", ":packet",
@ -977,6 +998,8 @@ cc_library(
hdrs = ["subgraph.h"], hdrs = ["subgraph.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":graph_service",
":graph_service_manager",
":port", ":port",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",
@ -989,6 +1012,8 @@ cc_library(
"//mediapipe/framework/tool:template_expander", "//mediapipe/framework/tool:template_expander",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
], ],
) )
@ -1008,7 +1033,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1102,6 +1127,7 @@ cc_library(
deps = [ deps = [
":calculator_base", ":calculator_base",
":calculator_contract", ":calculator_contract",
":graph_service_manager",
":legacy_calculator_support", ":legacy_calculator_support",
":packet", ":packet",
":packet_generator", ":packet_generator",
@ -1136,6 +1162,24 @@ cc_library(
], ],
) )
cc_test(
name = "validated_graph_config_test",
srcs = ["validated_graph_config_test.cc"],
deps = [
":calculator_framework",
":graph_service",
":graph_service_manager",
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "graph_validation", name = "graph_validation",
hdrs = ["graph_validation.h"], hdrs = ["graph_validation.h"],
@ -1591,13 +1635,16 @@ cc_test(
srcs = ["subgraph_test.cc"], srcs = ["subgraph_test.cc"],
deps = [ deps = [
":calculator_framework", ":calculator_framework",
":graph_service_manager",
":subgraph", ":subgraph",
":test_calculators", ":test_calculators",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:sink",
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
"@com_google_absl//absl/strings:str_format",
], ],
) )

View File

@ -41,9 +41,9 @@ Counter* CalculatorContext::GetCounter(const std::string& name) {
return calculator_state_->GetCounter(name); return calculator_state_->GetCounter(name);
} }
CounterSet* CalculatorContext::GetCounterSet() { CounterFactory* CalculatorContext::GetCounterFactory() {
CHECK(calculator_state_); CHECK(calculator_state_);
return calculator_state_->GetCounterSet(); return calculator_state_->GetCounterFactory();
} }
const PacketSet& CalculatorContext::InputSidePackets() const { const PacketSet& CalculatorContext::InputSidePackets() const {

View File

@ -76,7 +76,7 @@ class CalculatorContext {
// Returns the counter set, which can be used to create new counters. // Returns the counter set, which can be used to create new counters.
// No prefix is added to counters created in this way. // No prefix is added to counters created in this way.
CounterSet* GetCounterSet(); CounterFactory* GetCounterFactory();
// Returns the current input timestamp, or Timestamp::Unset if there are // Returns the current input timestamp, or Timestamp::Unset if there are
// no input packets. // no input packets.
@ -113,26 +113,9 @@ class CalculatorContext {
return calculator_state_->GetSharedProfilingContext().get(); return calculator_state_->GetSharedProfilingContext().get();
} }
template <typename T>
class ServiceBinding {
public:
bool IsAvailable() {
return calculator_state_->IsServiceAvailable(service_);
}
T& GetObject() { return calculator_state_->GetServiceObject(service_); }
ServiceBinding(CalculatorState* calculator_state,
const GraphService<T>& service)
: calculator_state_(calculator_state), service_(service) {}
private:
CalculatorState* calculator_state_;
const GraphService<T>& service_;
};
template <typename T> template <typename T>
ServiceBinding<T> Service(const GraphService<T>& service) { ServiceBinding<T> Service(const GraphService<T>& service) {
return ServiceBinding<T>(calculator_state_, service); return ServiceBinding<T>(calculator_state_->GetServiceObject(service));
} }
private: private:

View File

@ -36,6 +36,7 @@
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/delegating_executor.h" #include "mediapipe/framework/delegating_executor.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/input_stream_manager.h" #include "mediapipe/framework/input_stream_manager.h"
#include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/packet_generator.h" #include "mediapipe/framework/packet_generator.h"
@ -392,7 +393,8 @@ absl::Status CalculatorGraph::Initialize(
const CalculatorGraphConfig& input_config, const CalculatorGraphConfig& input_config,
const std::map<std::string, Packet>& side_packets) { const std::map<std::string, Packet>& side_packets) {
auto validated_graph = absl::make_unique<ValidatedGraphConfig>(); auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize(input_config)); MP_RETURN_IF_ERROR(validated_graph->Initialize(
input_config, /*graph_registry=*/nullptr, &service_manager_));
return Initialize(std::move(validated_graph), side_packets); return Initialize(std::move(validated_graph), side_packets);
} }
@ -402,8 +404,8 @@ absl::Status CalculatorGraph::Initialize(
const std::map<std::string, Packet>& side_packets, const std::map<std::string, Packet>& side_packets,
const std::string& graph_type, const Subgraph::SubgraphOptions* options) { const std::string& graph_type, const Subgraph::SubgraphOptions* options) {
auto validated_graph = absl::make_unique<ValidatedGraphConfig>(); auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize(input_configs, input_templates, MP_RETURN_IF_ERROR(validated_graph->Initialize(
graph_type, options)); input_configs, input_templates, graph_type, options, &service_manager_));
return Initialize(std::move(validated_graph), side_packets); return Initialize(std::move(validated_graph), side_packets);
} }
@ -509,19 +511,15 @@ absl::Status CalculatorGraph::StartRun(
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
absl::Status CalculatorGraph::SetGpuResources( absl::Status CalculatorGraph::SetGpuResources(
std::shared_ptr<::mediapipe::GpuResources> resources) { std::shared_ptr<::mediapipe::GpuResources> resources) {
RET_CHECK(!ContainsKey(service_packets_, kGpuService.key)) auto gpu_service = service_manager_.GetServiceObject(kGpuService);
RET_CHECK_EQ(gpu_service, nullptr)
<< "The GPU resources have already been configured."; << "The GPU resources have already been configured.";
service_packets_[kGpuService.key] = return service_manager_.SetServiceObject(kGpuService, std::move(resources));
MakePacket<std::shared_ptr<::mediapipe::GpuResources>>(
std::move(resources));
return absl::OkStatus();
} }
std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources() std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
const { const {
auto service_iter = service_packets_.find(kGpuService.key); return service_manager_.GetServiceObject(kGpuService);
if (service_iter == service_packets_.end()) return nullptr;
return service_iter->second.Get<std::shared_ptr<::mediapipe::GpuResources>>();
} }
absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu( absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
@ -536,8 +534,7 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
} }
} }
if (uses_gpu) { if (uses_gpu) {
auto service_iter = service_packets_.find(kGpuService.key); auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
bool has_service = service_iter != service_packets_.end();
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName); auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
// Workaround for b/116875321: CalculatorRunner provides an empty packet, // Workaround for b/116875321: CalculatorRunner provides an empty packet,
@ -545,15 +542,12 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
bool has_legacy_sp = legacy_sp_iter != side_packets.end() && bool has_legacy_sp = legacy_sp_iter != side_packets.end() &&
!legacy_sp_iter->second.IsEmpty(); !legacy_sp_iter->second.IsEmpty();
std::shared_ptr<::mediapipe::GpuResources> gpu_resources; if (gpu_resources) {
if (has_service) {
if (has_legacy_sp) { if (has_legacy_sp) {
LOG(WARNING) LOG(WARNING)
<< "::mediapipe::GpuSharedData provided as a side packet while the " << "::mediapipe::GpuSharedData provided as a side packet while the "
<< "graph already had one; ignoring side packet"; << "graph already had one; ignoring side packet";
} }
gpu_resources = service_iter->second
.Get<std::shared_ptr<::mediapipe::GpuResources>>();
update_sp = true; update_sp = true;
} else { } else {
if (has_legacy_sp) { if (has_legacy_sp) {
@ -564,8 +558,8 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create()); ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
update_sp = true; update_sp = true;
} }
service_packets_[kGpuService.key] = MP_RETURN_IF_ERROR(
MakePacket<std::shared_ptr<::mediapipe::GpuResources>>(gpu_resources); service_manager_.SetServiceObject(kGpuService, gpu_resources));
} }
// Create or replace the legacy side packet if needed. // Create or replace the legacy side packet if needed.
@ -682,8 +676,10 @@ absl::Status CalculatorGraph::PrepareForRun(
std::placeholders::_1, std::placeholders::_2); std::placeholders::_1, std::placeholders::_2);
node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback);
scheduler_.AssignNodeToSchedulerQueue(&node); scheduler_.AssignNodeToSchedulerQueue(&node);
// TODO: update calculator node to use GraphServiceManager
// instead of service packets?
const absl::Status result = node.PrepareForRun( const absl::Status result = node.PrepareForRun(
current_run_side_packets_, service_packets_, current_run_side_packets_, service_manager_.ServicePackets(),
std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_, std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_,
&node), &node),
std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_, std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_,
@ -811,6 +807,11 @@ absl::Status CalculatorGraph::AddPacketToInputStreamInternal(
CHECK_GE(node_id, validated_graph_->CalculatorInfos().size()); CHECK_GE(node_id, validated_graph_->CalculatorInfos().size());
{ {
absl::MutexLock lock(&full_input_streams_mutex_); absl::MutexLock lock(&full_input_streams_mutex_);
if (full_input_streams_.empty()) {
return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC)
<< "CalculatorGraph::AddPacketToInputStream() is called before "
"StartRun()";
}
if (graph_input_stream_add_mode_ == if (graph_input_stream_add_mode_ ==
GraphInputStreamAddMode::ADD_IF_NOT_FULL) { GraphInputStreamAddMode::ADD_IF_NOT_FULL) {
if (has_error_) { if (has_error_) {
@ -1170,21 +1171,6 @@ void CalculatorGraph::Pause() { scheduler_.Pause(); }
void CalculatorGraph::Resume() { scheduler_.Resume(); } void CalculatorGraph::Resume() { scheduler_.Resume(); }
absl::Status CalculatorGraph::SetServicePacket(const GraphServiceBase& service,
Packet p) {
// TODO: check that the graph has not been started!
service_packets_[service.key] = std::move(p);
return absl::OkStatus();
}
Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) {
auto it = service_packets_.find(service.key);
if (it == service_packets_.end()) {
return {};
}
return it->second;
}
absl::Status CalculatorGraph::SetExecutorInternal( absl::Status CalculatorGraph::SetExecutorInternal(
const std::string& name, std::shared_ptr<Executor> executor) { const std::string& name, std::shared_ptr<Executor> executor) {
if (!executors_.emplace(name, executor).second) { if (!executors_.emplace(name, executor).second) {

View File

@ -38,6 +38,7 @@
#include "mediapipe/framework/executor.h" #include "mediapipe/framework/executor.h"
#include "mediapipe/framework/graph_output_stream.h" #include "mediapipe/framework/graph_output_stream.h"
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/output_side_packet_impl.h" #include "mediapipe/framework/output_side_packet_impl.h"
#include "mediapipe/framework/output_stream.h" #include "mediapipe/framework/output_stream.h"
@ -377,19 +378,20 @@ class CalculatorGraph {
template <typename T> template <typename T>
absl::Status SetServiceObject(const GraphService<T>& service, absl::Status SetServiceObject(const GraphService<T>& service,
std::shared_ptr<T> object) { std::shared_ptr<T> object) {
return SetServicePacket(service, // TODO: check that the graph has not been started!
MakePacket<std::shared_ptr<T>>(std::move(object))); return service_manager_.SetServiceObject(service, object);
} }
template <typename T> template <typename T>
std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) { std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) {
Packet p = GetServicePacket(service); return service_manager_.GetServiceObject(service);
if (p.IsEmpty()) return nullptr;
return p.Get<std::shared_ptr<T>>();
} }
// Only the Java API should call this directly. // Only the Java API should call this directly.
absl::Status SetServicePacket(const GraphServiceBase& service, Packet p); absl::Status SetServicePacket(const GraphServiceBase& service, Packet p) {
// TODO: check that the graph has not been started!
return service_manager_.SetServicePacket(service, p);
}
private: private:
// GraphRunState is used as a parameter in the function CallStatusHandlers. // GraphRunState is used as a parameter in the function CallStatusHandlers.
@ -523,7 +525,6 @@ class CalculatorGraph {
// status before taking any action. // status before taking any action.
void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full); void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
Packet GetServicePacket(const GraphServiceBase& service);
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
// Owns the legacy GpuSharedData if we need to create one for backwards // Owns the legacy GpuSharedData if we need to create one for backwards
// compatibility. // compatibility.
@ -598,7 +599,8 @@ class CalculatorGraph {
// The processed input side packet map for this run. // The processed input side packet map for this run.
std::map<std::string, Packet> current_run_side_packets_; std::map<std::string, Packet> current_run_side_packets_;
std::map<std::string, Packet> service_packets_; // Object to manage graph services.
GraphServiceManager service_manager_;
// Vector of errors encountered while running graph. Always use RecordError() // Vector of errors encountered while running graph. Always use RecordError()
// to add an error to this vector. // to add an error to this vector.

View File

@ -1361,6 +1361,38 @@ TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
TEST(CalculatorGraphBoundsTest, PostStreamPacketToSetProcessTimestampBound) {
std::string config_str = R"(
input_stream: "input_0"
node {
calculator: "ProcessBoundToPacketCalculator"
input_stream: "input_0"
output_stream: "output_0"
}
)";
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
std::vector<Packet> output_0_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
output_0_packets.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_0", MakePacket<int>(0).At(Timestamp::PostStream())));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(output_0_packets.size(), 1);
EXPECT_EQ(output_0_packets[0].Timestamp(), Timestamp::PostStream());
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// A Calculator that sends a timestamp bound for every other input. // A Calculator that sends a timestamp bound for every other input.
class OccasionalBoundCalculator : public CalculatorBase { class OccasionalBoundCalculator : public CalculatorBase {
public: public:

View File

@ -4356,256 +4356,5 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) {
ASSERT_EQ(5, packet_dump.size()); ASSERT_EQ(5, packet_dump.size());
} }
// Returns the first packet of the input stream.
class FirstPacketFilterCalculator : public CalculatorBase {
public:
FirstPacketFilterCalculator() {}
~FirstPacketFilterCalculator() override {}
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (!seen_first_packet_) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
cc->Outputs().Index(0).Close();
seen_first_packet_ = true;
}
return absl::OkStatus();
}
private:
bool seen_first_packet_ = false;
};
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
constexpr int kDefaultMaxCount = 1000;
TEST(CalculatorGraph, TestPollPacket) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
for (int queue_size = 1; queue_size < 10; ++queue_size) {
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
poller.SetMaxQueueSize(queue_size);
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
}
TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node1 = config.add_node();
node1->set_calculator("CountingSourceCalculator");
node1->add_output_stream("stream1");
node1->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraphConfig::Node* node2 = config.add_node();
node2->set_calculator("PassThroughCalculator");
node2->add_input_stream("stream1");
node2->add_output_stream("stream2");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller1 = graph.AddOutputStreamPoller("stream1");
ASSERT_TRUE(status_or_poller1.ok());
OutputStreamPoller poller1 = std::move(status_or_poller1.value());
auto status_or_poller2 = graph.AddOutputStreamPoller("stream2");
ASSERT_TRUE(status_or_poller2.ok());
OutputStreamPoller poller2 = std::move(status_or_poller2.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet1;
Packet packet2;
int num_packets1 = 0;
int num_packets2 = 0;
int running_pollers = 2;
while (running_pollers > 0) {
if (poller1.Next(&packet1)) {
EXPECT_EQ(num_packets1++, packet1.Get<int>());
} else {
--running_pollers;
}
if (poller2.Next(&packet2)) {
EXPECT_EQ(num_packets2++, packet2.Get<int>());
} else {
--running_pollers;
}
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller1.Next(&packet1));
EXPECT_FALSE(poller2.Next(&packet2));
EXPECT_EQ(kDefaultMaxCount, num_packets1);
EXPECT_EQ(kDefaultMaxCount, num_packets2);
}
// Ensure that when a custom input stream handler is used to handle packets from
// input streams, an error message is outputted with the appropriate link to
// resolve the issue when the calculator doesn't handle inputs in monotonically
// increasing order of timestamps.
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input0'
input_stream: 'input1'
node {
calculator: 'SimpleMuxCalculator'
input_stream: 'input0'
input_stream: 'input1'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
output_stream: 'output'
}
)");
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Send packets to input stream "input0" at timestamps 0 and 1 consecutively.
Timestamp input0_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(1, packet_dump[0].Get<int>());
++input0_timestamp;
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(3).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(2, packet_dump.size());
EXPECT_EQ(3, packet_dump[1].Get<int>());
// Send a packet to input stream "input1" at timestamp 0 after sending two
// packets at timestamps 0 and 1 to input stream "input0". This will result
// in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle
// inputs from all streams in monotonically increasing order of timestamps.
Timestamp input1_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(2).At(input1_timestamp)));
absl::Status run_status = graph.WaitUntilIdle();
EXPECT_THAT(
run_status.ToString(),
testing::AllOf(
// The core problem.
testing::HasSubstr("timestamp mismatch on a calculator"),
testing::HasSubstr(
"timestamps that are not strictly monotonically increasing"),
// Link to the possible solution.
testing::HasSubstr("ImmediateInputStreamHandler class comment")));
}
void DoTestMultipleGraphRuns(absl::string_view input_stream_handler,
bool select_packet) {
std::string graph_proto = absl::StrFormat(R"(
input_stream: 'input'
input_stream: 'select'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'select'
input_stream_handler {
input_stream_handler: "%s"
}
output_stream: 'output'
output_stream: 'select_out'
}
)",
input_stream_handler.data());
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
struct Run {
Timestamp timestamp;
int value;
};
std::vector<Run> runs = {{.timestamp = Timestamp(2000), .value = 2},
{.timestamp = Timestamp(1000), .value = 1}};
for (const Run& run : runs) {
MP_ASSERT_OK(graph.StartRun({}));
if (select_packet) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(run.timestamp)));
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(run.value).At(run.timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(run.value, packet_dump[0].Get<int>());
EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
packet_dump.clear();
}
}
TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) {
DoTestMultipleGraphRuns("BarrierInputStreamHandler", true);
DoTestMultipleGraphRuns("DefaultInputStreamHandler", true);
DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true);
DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true);
DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false);
DoTestMultipleGraphRuns("MuxInputStreamHandler", true);
DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true);
DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true);
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -408,13 +408,13 @@ absl::Status CalculatorNode::PrepareForRun(
validated_graph_->CalculatorInfos()[node_id_].Contract(); validated_graph_->CalculatorInfos()[node_id_].Contract();
for (const auto& svc_req : contract.ServiceRequests()) { for (const auto& svc_req : contract.ServiceRequests()) {
const auto& req = svc_req.second; const auto& req = svc_req.second;
std::string key{req.Service().key}; auto it = service_packets.find(req.Service().key);
auto it = service_packets.find(key);
if (it == service_packets.end()) { if (it == service_packets.end()) {
RET_CHECK(req.IsOptional()) RET_CHECK(req.IsOptional())
<< "required service '" << key << "' was not provided"; << "required service '" << req.Service().key << "' was not provided";
} else { } else {
calculator_state_->SetServicePacket(key, it->second); MP_RETURN_IF_ERROR(
calculator_state_->SetServicePacket(req.Service(), it->second));
} }
} }

View File

@ -61,13 +61,9 @@ Counter* CalculatorState::GetCounter(const std::string& name) {
return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name)); return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name));
} }
CounterSet* CalculatorState::GetCounterSet() { CounterFactory* CalculatorState::GetCounterFactory() {
CHECK(counter_factory_); CHECK(counter_factory_);
return counter_factory_->GetCounterSet(); return counter_factory_;
}
void CalculatorState::SetServicePacket(const std::string& key, Packet packet) {
service_packets_[key] = std::move(packet);
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -27,6 +27,7 @@
#include "mediapipe/framework/counter.h" #include "mediapipe/framework/counter.h"
#include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
@ -81,7 +82,7 @@ class CalculatorState {
// Returns a counter set, which can be passed to other classes, to generate // Returns a counter set, which can be passed to other classes, to generate
// counters. NOTE: This differs from GetCounter, in that the counters // counters. NOTE: This differs from GetCounter, in that the counters
// created by this counter set do not have the NodeName prefix. // created by this counter set do not have the NodeName prefix.
CounterSet* GetCounterSet(); CounterFactory* GetCounterFactory();
std::shared_ptr<ProfilingContext> GetSharedProfilingContext() const { std::shared_ptr<ProfilingContext> GetSharedProfilingContext() const {
return profiling_context_; return profiling_context_;
@ -99,17 +100,14 @@ class CalculatorState {
counter_factory_ = counter_factory; counter_factory_ = counter_factory;
} }
void SetServicePacket(const std::string& key, Packet packet); absl::Status SetServicePacket(const GraphServiceBase& service,
Packet packet) {
bool IsServiceAvailable(const GraphServiceBase& service) { return graph_service_manager_.SetServicePacket(service, packet);
return ContainsKey(service_packets_, service.key);
} }
template <typename T> template <typename T>
T& GetServiceObject(const GraphService<T>& service) { std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) {
auto it = service_packets_.find(service.key); return graph_service_manager_.GetServiceObject(service);
CHECK(it != service_packets_.end());
return *it->second.template Get<std::shared_ptr<T>>();
} }
private: private:
@ -129,7 +127,7 @@ class CalculatorState {
// The graph tracing and profiling interface. // The graph tracing and profiling interface.
std::shared_ptr<ProfilingContext> profiling_context_; std::shared_ptr<ProfilingContext> profiling_context_;
std::map<std::string, Packet> service_packets_; GraphServiceManager graph_service_manager_;
//////////////////////////////////////// ////////////////////////////////////////
// Variables which ARE cleared by ResetBetweenRuns(). // Variables which ARE cleared by ResetBetweenRuns().

View File

@ -37,7 +37,7 @@ inline StatusBuilder RetCheckImpl(const absl::Status& status,
const char* condition, const char* condition,
mediapipe::source_location location) { mediapipe::source_location location) {
if (ABSL_PREDICT_TRUE(status.ok())) if (ABSL_PREDICT_TRUE(status.ok()))
return mediapipe::StatusBuilder(OkStatus(), location); return mediapipe::StatusBuilder(absl::OkStatus(), location);
return RetCheckFailSlowPath(location, condition, status); return RetCheckFailSlowPath(location, condition, status);
} }

View File

@ -18,7 +18,7 @@
namespace mediapipe { namespace mediapipe {
std::ostream& operator<<(std::ostream& os, const Status& x) { std::ostream& operator<<(std::ostream& os, const absl::Status& x) {
os << x.ToString(); os << x.ToString();
return os; return os;
} }

View File

@ -194,10 +194,10 @@ namespace status_macro_internal {
// that declares a variable. // that declares a variable.
class StatusAdaptorForMacros { class StatusAdaptorForMacros {
public: public:
StatusAdaptorForMacros(const Status& status, const char* file, int line) StatusAdaptorForMacros(const absl::Status& status, const char* file, int line)
: builder_(status, file, line) {} : builder_(status, file, line) {}
StatusAdaptorForMacros(Status&& status, const char* file, int line) StatusAdaptorForMacros(absl::Status&& status, const char* file, int line)
: builder_(std::move(status), file, line) {} : builder_(std::move(status), file, line) {}
StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */,

View File

@ -79,12 +79,9 @@ def _get_proto_provider(dep):
def _encode_binary_proto_impl(ctx): def _encode_binary_proto_impl(ctx):
"""Implementation of the encode_binary_proto rule.""" """Implementation of the encode_binary_proto rule."""
all_protos = depset()
for dep in ctx.attr.deps:
provider = _get_proto_provider(dep)
all_protos = depset( all_protos = depset(
direct = [], direct = [],
transitive = [all_protos, provider.transitive_sources], transitive = [_get_proto_provider(dep).transitive_sources for dep in ctx.attr.deps],
) )
textpb = ctx.file.input textpb = ctx.file.input
@ -120,7 +117,7 @@ def _encode_binary_proto_impl(ctx):
data_runfiles = ctx.runfiles(transitive_files = output_depset), data_runfiles = ctx.runfiles(transitive_files = output_depset),
)] )]
encode_binary_proto = rule( _encode_binary_proto = rule(
implementation = _encode_binary_proto_impl, implementation = _encode_binary_proto_impl,
attrs = { attrs = {
"_proto_compiler": attr.label( "_proto_compiler": attr.label(
@ -142,6 +139,15 @@ encode_binary_proto = rule(
}, },
) )
def encode_binary_proto(name, input, message_type, deps, **kwargs):
_encode_binary_proto(
name = name,
input = input,
message_type = message_type,
deps = deps,
**kwargs
)
def _generate_proto_descriptor_set_impl(ctx): def _generate_proto_descriptor_set_impl(ctx):
"""Implementation of the generate_proto_descriptor_set rule.""" """Implementation of the generate_proto_descriptor_set rule."""
all_protos = depset(transitive = [ all_protos = depset(transitive = [

View File

@ -114,7 +114,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@eigen_archive//:eigen", "@eigen_archive//:eigen3",
], ],
) )
@ -260,9 +260,11 @@ mediapipe_register_type(
include_headers = ["mediapipe/framework/formats/landmark.pb.h"], include_headers = ["mediapipe/framework/formats/landmark.pb.h"],
types = [ types = [
"::mediapipe::Landmark", "::mediapipe::Landmark",
"::mediapipe::LandmarkList",
"::mediapipe::NormalizedLandmark", "::mediapipe::NormalizedLandmark",
"::mediapipe::NormalizedLandmarkList", "::mediapipe::NormalizedLandmarkList",
"::std::vector<::mediapipe::Landmark>", "::std::vector<::mediapipe::Landmark>",
"::std::vector<::mediapipe::LandmarkList>",
"::std::vector<::mediapipe::NormalizedLandmark>", "::std::vector<::mediapipe::NormalizedLandmark>",
"::std::vector<::mediapipe::NormalizedLandmarkList>", "::std::vector<::mediapipe::NormalizedLandmarkList>",
], ],

View File

@ -31,6 +31,8 @@ message Classification {
optional float score = 2; optional float score = 2;
// Label or name of the class. // Label or name of the class.
optional string label = 3; optional string label = 3;
// Optional human-readable string for display purposes.
optional string display_name = 4;
} }
// Group of Classification protos. // Group of Classification protos.

View File

@ -78,6 +78,12 @@ class Image {
pixel_mutex_ = std::make_shared<absl::Mutex>(); pixel_mutex_ = std::make_shared<absl::Mutex>();
} }
// CPU getters.
const ImageFrameSharedPtr& GetImageFrameSharedPtr() const {
if (use_gpu_ == true) ConvertToCpu();
return image_frame_;
}
// Creates an Image representing the same image content as the input GPU // Creates an Image representing the same image content as the input GPU
// buffer in platform-specific representations. // buffer in platform-specific representations.
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -95,13 +101,8 @@ class Image {
gpu_buffer_ = gpu_buffer; gpu_buffer_ = gpu_buffer;
pixel_mutex_ = std::make_shared<absl::Mutex>(); pixel_mutex_ = std::make_shared<absl::Mutex>();
} }
#endif // !MEDIAPIPE_DISABLE_GPU
const ImageFrameSharedPtr& GetImageFrameSharedPtr() const { // GPU getters.
if (use_gpu_ == true) ConvertToCpu();
return image_frame_;
}
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
CVPixelBufferRef GetCVPixelBufferRef() const { CVPixelBufferRef GetCVPixelBufferRef() const {
if (use_gpu_ == false) ConvertToGpu(); if (use_gpu_ == false) ConvertToGpu();

View File

@ -47,8 +47,8 @@ message LandmarkList {
repeated Landmark landmark = 1; repeated Landmark landmark = 1;
} }
// A normalized version of above Landmark proto. All coordiates should be within // A normalized version of above Landmark proto. All coordinates should be
// [0, 1]. // within [0, 1].
message NormalizedLandmark { message NormalizedLandmark {
optional float x = 1; optional float x = 1;
optional float y = 2; optional float y = 2;

View File

@ -67,11 +67,11 @@ cc_test(
deps = [ deps = [
":optical_flow_field", ":optical_flow_field",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"@com_google_absl//absl/flags:flag",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
) )

Some files were not shown because too many files have changed in this diff Show More