Project import generated by Copybara.
GitOrigin-RevId: 6e4aff1cc351be3ae4537b677f36d139ee50ce09
This commit is contained in:
parent a92cff7a60
commit 7c331ad58b
|
@ -54,7 +54,7 @@ RUN pip3 install tf_slim
|
|||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||
|
||||
# Install bazel
|
||||
ARG BAZEL_VERSION=3.4.1
|
||||
ARG BAZEL_VERSION=3.7.2
|
||||
RUN mkdir /bazel && \
|
||||
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
|
||||
azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
||||
|
|
|
@ -10,3 +10,7 @@ include requirements.txt
|
|||
recursive-include mediapipe/modules *.tflite *.txt *.binarypb
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_sneakers.tflite
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_chair.tflite
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_camera.tflite
|
||||
exclude mediapipe/modules/objectron/object_detection_3d_cup.tflite
|
||||
|
|
|
@ -44,7 +44,7 @@ Hair Segmentation
|
|||
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
|
||||
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
|
||||
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
|
||||
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
|
||||
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
|
||||
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
|
||||
|
|
WORKSPACE (38 changed lines)
|
@ -2,16 +2,19 @@ workspace(name = "mediapipe")
|
|||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
skylib_version = "0.9.0"
|
||||
http_archive(
|
||||
name = "bazel_skylib",
|
||||
type = "tar.gz",
|
||||
url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel_skylib-{}.tar.gz".format (skylib_version, skylib_version),
|
||||
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
|
||||
urls = [
|
||||
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz",
|
||||
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz",
|
||||
],
|
||||
sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c",
|
||||
)
|
||||
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
|
||||
bazel_skylib_workspace()
|
||||
load("@bazel_skylib//lib:versions.bzl", "versions")
|
||||
versions.check(minimum_bazel_version = "3.4.0")
|
||||
|
||||
versions.check(minimum_bazel_version = "3.7.2")
|
||||
|
||||
# ABSL cpp library lts_2020_09_23
|
||||
http_archive(
|
||||
|
@ -38,8 +41,8 @@ http_archive(
|
|||
|
||||
http_archive(
|
||||
name = "rules_foreign_cc",
|
||||
strip_prefix = "rules_foreign_cc-main",
|
||||
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/main.zip",
|
||||
strip_prefix = "rules_foreign_cc-0.1.0",
|
||||
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
|
||||
)
|
||||
|
||||
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
|
||||
|
@ -304,8 +307,8 @@ http_archive(
|
|||
|
||||
# Maven dependencies.
|
||||
|
||||
RULES_JVM_EXTERNAL_TAG = "3.2"
|
||||
RULES_JVM_EXTERNAL_SHA = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af"
|
||||
RULES_JVM_EXTERNAL_TAG = "4.0"
|
||||
RULES_JVM_EXTERNAL_SHA = "31701ad93dbfe544d597dbe62c9a1fdd76d81d8a9150c2bf1ecf928ecdf97169"
|
||||
|
||||
http_archive(
|
||||
name = "rules_jvm_external",
|
||||
|
@ -318,7 +321,6 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
|
|||
|
||||
# Important: there can only be one maven_install rule. Add new maven deps here.
|
||||
maven_install(
|
||||
name = "maven",
|
||||
artifacts = [
|
||||
"androidx.concurrent:concurrent-futures:1.0.0-alpha03",
|
||||
"androidx.lifecycle:lifecycle-common:2.2.0",
|
||||
|
@ -343,10 +345,10 @@ maven_install(
|
|||
"org.hamcrest:hamcrest-library:1.3",
|
||||
],
|
||||
repositories = [
|
||||
"https://jcenter.bintray.com",
|
||||
"https://maven.google.com",
|
||||
"https://dl.google.com/dl/android/maven2",
|
||||
"https://repo1.maven.org/maven2",
|
||||
"https://jcenter.bintray.com",
|
||||
],
|
||||
fetch_sources = True,
|
||||
version_conflict_policy = "pinned",
|
||||
|
@ -363,10 +365,10 @@ http_archive(
|
|||
],
|
||||
)
|
||||
|
||||
#Tensorflow repo should always go after the other external dependencies.
|
||||
# 2020-12-09
|
||||
_TENSORFLOW_GIT_COMMIT = "0eadbb13cef1226b1bae17c941f7870734d97f8a"
|
||||
_TENSORFLOW_SHA256= "4ae06daa5b09c62f31b7bc1f781fd59053f286dd64355830d8c2ac601b795ef0"
|
||||
# Tensorflow repo should always go after the other external dependencies.
|
||||
# 2021-03-25
|
||||
_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887"
|
||||
_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f"
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
urls = [
|
||||
|
@ -383,5 +385,7 @@ http_archive(
|
|||
sha256 = _TENSORFLOW_SHA256,
|
||||
)
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
|
||||
tf_workspace(tf_repo_name = "org_tensorflow")
|
||||
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
|
||||
tf_workspace3()
|
||||
load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
|
||||
tf_workspace2()
|
||||
|
|
|
@ -12,19 +12,30 @@ nav_order: 3
{:toc}
---

Each calculator is a node of a graph. We describe how to create a new calculator, how to initialize a calculator, how to perform its calculations, input and output streams, timestamps, and options.
Calculators communicate by sending and receiving packets. Typically a single
packet is sent along each input stream at each input timestamp. A packet can
contain any kind of data, such as a single frame of video or a single integer
detection count.

## Creating a packet

Packets are generally created with `MediaPipe::Adopt()` (from packet.h).
Packets are generally created with `mediapipe::MakePacket<T>()` or
`mediapipe::Adopt()` (from packet.h).

```c++
// Create some data.
auto data = absl::make_unique<MyDataClass>("constructor_argument");
// Create a packet to own the data.
Packet p = Adopt(data.release());
// Create a packet containing some new data.
Packet p = MakePacket<MyDataClass>("constructor_argument");
// Make a new packet with the same data and a different timestamp.
Packet p2 = p.At(Timestamp::PostStream());
```

or:

```c++
// Create some new data.
auto data = absl::make_unique<MyDataClass>("constructor_argument");
// Create a packet to own the data.
Packet p = Adopt(data.release()).At(Timestamp::PostStream());
```

Data within a packet is accessed with `Packet::Get<T>()`
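For example, a minimal sketch of reading the payload back out (reusing `MyDataClass` from the examples above; `ValidateAsType<T>()` is assumed here as the non-fatal way to check the payload type first):

```c++
// Read the payload back; Get<T>() returns a const reference to the data.
const MyDataClass& value = p.Get<MyDataClass>();

// The timestamp travels with the packet.
Timestamp t = p.Timestamp();

// When the payload type is uncertain, validate before calling Get<T>().
if (p.ValidateAsType<MyDataClass>().ok()) {
  // p.Get<MyDataClass>() is safe here.
}
```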
@ -28,7 +28,7 @@ Gradle.
|
|||
* Install MediaPipe following these [instructions](./install.md).
|
||||
* Setup Java Runtime.
|
||||
* Setup Android SDK release 28.0.3 and above.
|
||||
* Setup Android NDK r18b and above.
|
||||
* Setup Android NDK version between 18 and 21.
|
||||
|
||||
MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see
|
||||
below for Android Studio setup). However, if you prefer using MediaPipe without
|
||||
|
|
|
@ -25,25 +25,11 @@ install --user six`.
|
|||
|
||||
## Installing on Debian and Ubuntu
|
||||
|
||||
1. Install Bazel.
|
||||
1. Install Bazelisk.
|
||||
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
|
||||
to install Bazel 3.4 or higher.
|
||||
|
||||
For Nvidia Jetson and Raspberry Pi devices with aarch64 Linux, Bazel needs
|
||||
to be built from source:
|
||||
|
||||
```bash
|
||||
# For Bazel 3.4.1
|
||||
mkdir $HOME/bazel-3.4.1
|
||||
cd $HOME/bazel-3.4.1
|
||||
wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-dist.zip
|
||||
sudo apt-get install build-essential openjdk-8-jdk python zip unzip
|
||||
unzip bazel-3.4.1-dist.zip
|
||||
env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
|
||||
sudo cp output/bazel /usr/local/bin/
|
||||
```
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
|
||||
to install Bazelisk.
|
||||
|
||||
2. Checkout MediaPipe repository.
|
||||
|
||||
|
@ -207,11 +193,11 @@ build issues.
|
|||
|
||||
**Disclaimer**: Running MediaPipe on CentOS is experimental.
|
||||
|
||||
1. Install Bazel.
|
||||
1. Install Bazelisk.
|
||||
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html)
|
||||
to install Bazel 3.4 or higher.
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
|
||||
to install Bazelisk.
|
||||
|
||||
2. Checkout MediaPipe repository.
|
||||
|
||||
|
@ -336,11 +322,11 @@ build issues.
|
|||
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
|
||||
Tools by `xcode-select --install`.
|
||||
|
||||
2. Install Bazel.
|
||||
2. Install Bazelisk.
|
||||
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
|
||||
to install Bazel 3.4 or higher.
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
|
||||
to install Bazelisk.
|
||||
|
||||
3. Checkout MediaPipe repository.
|
||||
|
||||
|
@ -353,7 +339,7 @@ build issues.
|
|||
4. Install OpenCV and FFmpeg.
|
||||
|
||||
Option 1. Use HomeBrew package manager tool to install the pre-compiled
|
||||
OpenCV 3.4.5 libraries. FFmpeg will be installed via OpenCV.
|
||||
OpenCV 3 libraries. FFmpeg will be installed via OpenCV.
|
||||
|
||||
```bash
|
||||
$ brew install opencv@3
|
||||
|
@ -484,29 +470,36 @@ next section.
|
|||
|
||||
4. Install Visual C++ Build Tools 2019 and WinSDK
|
||||
|
||||
Go to https://visualstudio.microsoft.com/visual-cpp-build-tools, download
|
||||
build tools, and install Microsoft Visual C++ 2019 Redistributable and
|
||||
Microsoft Build Tools 2019.
|
||||
Go to
|
||||
[the Visual Studio website](https://visualstudio.microsoft.com/visual-cpp-build-tools),
|
||||
download build tools, and install Microsoft Visual C++ 2019 Redistributable
|
||||
and Microsoft Build Tools 2019.
|
||||
|
||||
Download the WinSDK from
|
||||
https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/ and
|
||||
install.
|
||||
[the official Microsoft website](https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/)
|
||||
and install.
|
||||
|
||||
5. Install Bazel and add the location of the Bazel executable to the `%PATH%`
|
||||
environment variable.
|
||||
5. Install Bazel or Bazelisk and add the location of the Bazel executable to
|
||||
the `%PATH%` environment variable.
|
||||
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html)
|
||||
to install Bazel 3.4 or higher.
|
||||
Option 1. Follow
|
||||
[the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html)
|
||||
to install Bazel 3.7.2 or higher.
|
||||
|
||||
6. Set Bazel variables.
|
||||
Option 2. Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
|
||||
to install Bazelisk.
|
||||
|
||||
6. Set Bazel variables. Learn more details about
|
||||
["Build on Windows"](https://docs.bazel.build/versions/master/windows.html#build-c-with-msvc)
|
||||
in the Bazel official documentation.
|
||||
|
||||
```
|
||||
# Find the exact paths and version numbers from your local version.
|
||||
# Please find the exact paths and version numbers from your local version.
|
||||
C:\> set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools
|
||||
C:\> set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC
|
||||
C:\> set BAZEL_VC_FULL_VERSION=14.25.28610
|
||||
C:\> set BAZEL_WINSDK_FULL_VERSION=10.1.18362.1
|
||||
C:\> set BAZEL_VC_FULL_VERSION=<Your local VC version>
|
||||
C:\> set BAZEL_WINSDK_FULL_VERSION=<Your local WinSDK version>
|
||||
```
|
||||
|
||||
7. Checkout MediaPipe repository.
|
||||
|
@ -593,19 +586,11 @@ cameras. Alternatively, you use a video file as input.
|
|||
username@DESKTOP-TMVLBJ1:~$ sudo apt-get update && sudo apt-get install -y build-essential git python zip adb openjdk-8-jdk
|
||||
```
|
||||
|
||||
5. Install Bazel.
|
||||
5. Install Bazelisk.
|
||||
|
||||
```bash
|
||||
username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
|
||||
https://storage.googleapis.com/bazel/3.4.1/release/bazel-3.4.1-installer-linux-x86_64.sh && \
|
||||
sudo mkdir -p /usr/local/bazel/3.4.1 && \
|
||||
chmod 755 bazel-3.4.1-installer-linux-x86_64.sh && \
|
||||
sudo ./bazel-3.4.1-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.4.1 && \
|
||||
source /usr/local/bazel/3.4.1/lib/bazel/bin/bazel-complete.bash
|
||||
|
||||
username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.4.1/lib/bazel/bin/bazel version && \
|
||||
alias bazel='/usr/local/bazel/3.4.1/lib/bazel/bin/bazel'
|
||||
```
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
|
||||
to install Bazelisk.
|
||||
|
||||
6. Checkout MediaPipe repository.
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ Hair Segmentation
|
|||
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
|
||||
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
|
||||
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
|
||||
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
|
||||
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
|
||||
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
|
||||
|
|
|
@ -183,8 +183,8 @@ function onResults(results) {
|
|||
canvasCtx.restore();
|
||||
}
|
||||
|
||||
const faceDetection = new Objectron({locateFile: (file) => {
|
||||
return `https://cdn.jsdelivr.net/npm/@mediapipe/objectron@0.0/${file}`;
|
||||
const faceDetection = new FaceDetection({locateFile: (file) => {
|
||||
return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`;
|
||||
}});
|
||||
faceDetection.setOptions({
|
||||
minDetectionConfidence: 0.5
|
||||
|
|
|
@ -358,15 +358,17 @@ cap.release()
|
|||
## Example Apps
|
||||
|
||||
Please first see general instructions for
|
||||
[Android](../getting_started/android.md) and [iOS](../getting_started/ios.md) on
|
||||
how to build MediaPipe examples.
|
||||
[Android](../getting_started/android.md), [iOS](../getting_started/ios.md), and
|
||||
[desktop](../getting_started/cpp.md) on how to build MediaPipe examples.
|
||||
|
||||
Note: To visualize a graph, copy the graph and paste it into
|
||||
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
|
||||
to visualize its associated subgraphs, please see
|
||||
[visualizer documentation](../tools/visualizer.md).
|
||||
|
||||
### Two-stage Objectron
|
||||
### Mobile
|
||||
|
||||
#### Two-stage Objectron
|
||||
|
||||
* Graph:
|
||||
[`mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt)
|
||||
|
@ -404,7 +406,7 @@ to visualize its associated subgraphs, please see
|
|||
|
||||
* iOS target: Not available
|
||||
|
||||
### Single-stage Objectron
|
||||
#### Single-stage Objectron
|
||||
|
||||
* Graph:
|
||||
[`mediapipe/graphs/object_detection_3d/object_occlusion_tracking_1stage.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/object_occlusion_tracking.pbtxt)
|
||||
|
@ -428,7 +430,7 @@ to visualize its associated subgraphs, please see
|
|||
|
||||
* iOS target: Not available
|
||||
|
||||
### Assets
|
||||
#### Assets
|
||||
|
||||
Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file
|
||||
format into a custom .uuu format. This can be done for user assets as follows:
|
||||
|
@ -449,9 +451,35 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http
|
|||
> single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as
|
||||
> absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details.
|
||||
|
||||
### Coordinate Systems
|
||||
|
||||
#### Object Coordinate
|
||||
### Desktop
|
||||
|
||||
To build the application, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_detection_3d:objectron_cpu
|
||||
```
|
||||
|
||||
To run the application, replace `<input video path>` and `<output video path>`
|
||||
in the command below with your own paths, and `<landmark model path>` and
|
||||
`<allowed labels>` with the following:
|
||||
|
||||
Category | `<landmark model path>` | `<allowed labels>`
:------- | :-------------------------------------------------------------------------- | :-----------------
Shoe     | mediapipe/modules/objectron/object_detection_3d_sneakers.tflite | Footwear
Chair    | mediapipe/modules/objectron/object_detection_3d_chair.tflite | Chair
Cup      | mediapipe/modules/objectron/object_detection_3d_cup.tflite | Mug
Camera   | mediapipe/modules/objectron/object_detection_3d_camera.tflite | Camera
|
||||
|
||||
```
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \
|
||||
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>,box_landmark_model_path=<landmark model path>,allowed_labels=<allowed labels>
|
||||
```
|
||||
|
||||
## Coordinate Systems
|
||||
|
||||
### Object Coordinate
|
||||
|
||||
Each object has its object coordinate frame. We use the below object coordinate
|
||||
definition, with `+x` pointing right, `+y` pointing up and `+z` pointing front,
|
||||
|
@ -459,7 +487,7 @@ origin is at the center of the 3D bounding box.
|
|||
|
||||
![box_coordinate.svg](../images/box_coordinate.svg)
#### Camera Coordinate
### Camera Coordinate

A 3D object is parameterized by its `scale` and `rotation`, `translation` with
regard to the camera coordinate frame. In this API we use the below camera
@ -476,7 +504,7 @@ camera frame by applying `rotation` and `translation`:
landmarks_3d = rotation * scale * unit_box + translation
```
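As a rough Eigen sketch of that transform (assumptions, not spelled out in this section: `scale` is applied per axis, and the unit box is a cube of side 1 centered at the object origin, i.e. corners at ±0.5):

```c++
#include <vector>
#include "Eigen/Dense"

// Sketch: recover the 3D box corners in the camera frame from the predicted
// rotation / scale / translation, per landmarks_3d = R * s * unit_box + t.
std::vector<Eigen::Vector3f> BoxCornersInCameraFrame(
    const Eigen::Matrix3f& rotation, const Eigen::Vector3f& scale,
    const Eigen::Vector3f& translation) {
  std::vector<Eigen::Vector3f> corners;
  for (float x : {-0.5f, 0.5f}) {
    for (float y : {-0.5f, 0.5f}) {
      for (float z : {-0.5f, 0.5f}) {
        const Eigen::Vector3f unit_corner(x, y, z);
        corners.push_back(rotation * scale.cwiseProduct(unit_corner) +
                          translation);
      }
    }
  }
  return corners;
}
```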
#### NDC Space
### NDC Space

In this API we use
[NDC(normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html)
@ -495,7 +523,7 @@ y_ndc = -fy * Y / Z + py
z_ndc = 1 / Z
```
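A small sketch of this projection in code. The `x_ndc` relation is assumed here to mirror `y_ndc`, i.e. `x_ndc = -fx * X / Z + px`, with `(fx, fy, px, py)` already expressed in NDC units and `Z` a non-zero depth:

```c++
#include "Eigen/Dense"

// Sketch: project a camera-frame point (X, Y, Z) into NDC using the pinhole
// relations above.
Eigen::Vector3f CameraToNdc(const Eigen::Vector3f& p_camera, float fx,
                            float fy, float px, float py) {
  const float X = p_camera.x(), Y = p_camera.y(), Z = p_camera.z();
  return Eigen::Vector3f(-fx * X / Z + px,  // x_ndc
                         -fy * Y / Z + py,  // y_ndc
                         1.0f / Z);         // z_ndc
}
```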
#### Pixel Space
### Pixel Space

In this API we set the upper-left corner of an image as the origin of the pixel
coordinate. One can convert from NDC to pixel space as follows:
|
||||
|
@ -532,10 +560,11 @@ py = -py_pixel * 2.0 / image_height + 1.0
|
|||
[Announcing the Objectron Dataset](https://ai.googleblog.com/2020/11/announcing-objectron-dataset.html)
|
||||
* Google AI Blog:
|
||||
[Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
|
||||
* Paper: [Objectron: A Large Scale Dataset of Object-Centric Videos in the Wild with Pose Annotations](https://arxiv.org/abs/2012.09988), to appear in CVPR 2021
|
||||
* Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
|
||||
Shape Supervision](https://arxiv.org/abs/2003.03522)
|
||||
* Paper:
|
||||
[Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
|
||||
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
|
||||
([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)), Fourth Workshop on Computer Vision for AR/VR, CVPR 2020
|
||||
* [Models and model cards](./models.md#objectron)
|
||||
* [Python Colab](https://mediapipe.page.link/objectron_py_colab)
|
||||
|
|
|
@ -25,10 +25,11 @@ One of the applications
|
|||
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
|
||||
can enable is fitness. More specifically - pose classification and repetition
|
||||
counting. In this section we'll provide basic guidance on building a custom pose
|
||||
classifier with the help of [Colabs](#colabs) and wrap it in a simple
|
||||
[fitness app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)
|
||||
powered by [ML Kit](https://developers.google.com/ml-kit). Push-ups and squats
|
||||
are used for demonstration purposes as the most common exercises.
|
||||
classifier with the help of [Colabs](#colabs) and wrap it in a simple fitness
|
||||
demo within
|
||||
[ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app).
|
||||
Push-ups and squats are used for demonstration purposes as the most common
|
||||
exercises.
|
||||
|
||||
![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) |
|
||||
:--------------------------------------------------------------------------------------------------------: |
|
||||
|
@ -47,7 +48,7 @@ determines the object's class based on the closest samples in the training set.
|
|||
classifier and form a training set using these [Colabs](#colabs),
|
||||
3. Perform the classification itself followed by repetition counting (e.g., in
|
||||
the
|
||||
[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app)).
|
||||
[ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app)).
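The classification in step 3 is, at its core, a nearest-neighbour vote over pose feature vectors. A hypothetical, self-contained sketch of that idea follows (the actual implementation lives in the Colabs and the ML Kit quickstart app; the types, distance choice, and labels here are illustrative):

```c++
#include <algorithm>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical training sample: a pose feature vector plus its class label,
// e.g. "pushups_down" or "squats_up".
struct PoseSample {
  std::vector<float> features;
  std::string label;
};

// Squared Euclidean distance; sufficient for ranking neighbours.
float SquaredL2(const std::vector<float>& a, const std::vector<float>& b) {
  float sum = 0.f;
  for (std::size_t i = 0; i < a.size(); ++i) {
    const float d = a[i] - b[i];
    sum += d * d;
  }
  return sum;
}

// Majority vote among the k training samples closest to the query pose.
std::string ClassifyPose(const std::vector<float>& query,
                         const std::vector<PoseSample>& training_set, int k) {
  std::vector<std::pair<float, const PoseSample*>> scored;
  scored.reserve(training_set.size());
  for (const PoseSample& s : training_set) {
    scored.emplace_back(SquaredL2(query, s.features), &s);
  }
  const std::size_t kk = std::min<std::size_t>(k, scored.size());
  std::partial_sort(scored.begin(), scored.begin() + kk, scored.end());
  std::map<std::string, int> votes;
  for (std::size_t i = 0; i < kk; ++i) ++votes[scored[i].second->label];
  std::string best;
  int best_count = 0;
  for (const auto& [label, count] : votes) {
    if (count > best_count) {
      best = label;
      best_count = count;
    }
  }
  return best;
}
```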
## Training Set
|
||||
|
||||
|
@ -76,7 +77,7 @@ video right in the Colab.
|
|||
|
||||
Code of the classifier is available both in the
|
||||
[`Pose Classification Colab (Extended)`] and in the
|
||||
[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app).
|
||||
[ML Kit quickstart app](https://developers.google.com/ml-kit/vision/pose-detection/classifying-poses#4_integrate_with_the_ml_kit_quickstart_app).
|
||||
Please refer to them for details of the approach described below.
|
||||
|
||||
The k-NN algorithm used for pose classification requires a feature vector
|
||||
|
@ -127,11 +128,13 @@ where the pose class and the counter can't be changed.
|
|||
|
||||
## Future Work
|
||||
|
||||
We are actively working on improving BlazePose GHUM 3D's Z prediction. It will
|
||||
allow us to use joint angles in the feature vectors, which are more natural and
|
||||
easier to configure (although distances can still be useful to detect touches
|
||||
between body parts) and to perform rotation normalization of poses and reduce
|
||||
the number of camera angles required for accurate k-NN classification.
|
||||
We are actively working on improving
|
||||
[BlazePose GHUM 3D](./pose.md#pose-landmark-model-blazepose-ghum-3d)'s Z
|
||||
prediction. It will allow us to use joint angles in the feature vectors, which
|
||||
are more natural and easier to configure (although distances can still be useful
|
||||
to detect touches between body parts) and to perform rotation normalization of
|
||||
poses and reduce the number of camera angles required for accurate k-NN
|
||||
classification.
|
||||
|
||||
## Colabs
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ has_toc: false
|
|||
[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅
|
||||
[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | |
|
||||
[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | |
|
||||
[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | |
|
||||
[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | |
|
||||
[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | |
|
||||
[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | |
|
||||
|
|
|
@ -41,6 +41,7 @@ profiler_config {
|
|||
trace_enabled: true
|
||||
enable_profiler: true
|
||||
trace_log_interval_count: 200
|
||||
trace_log_path: "/sdcard/Download/"
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -64,7 +65,7 @@ MediaPipe will emit data into a pre-specified directory:
|
|||
|
||||
* On the desktop, this will be the `/tmp` directory.
|
||||
|
||||
* On Android, this will be the `/sdcard` directory.
|
||||
* On Android, this will be the external storage directory (e.g., `/storage/emulated/0/`).
|
||||
|
||||
* On iOS, this can be reached through XCode. Select "Window/Devices and
|
||||
Simulators" and select the "Devices" tab.
|
||||
|
@ -103,7 +104,7 @@ we record ten intervals of half a second each. This can be overridden by adding
|
|||
* Include the line below in your `AndroidManifest.xml` file.
|
||||
|
||||
```xml
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
|
||||
<uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE" />
|
||||
```
|
||||
|
||||
* Grant the permission either upon first app launch, or by going into
|
||||
|
@ -130,8 +131,8 @@ we record ten intervals of half a second each. This can be overridden by adding
|
|||
events to a trace log files at:
|
||||
|
||||
```bash
|
||||
/sdcard/mediapipe_trace_0.binarypb
|
||||
/sdcard/mediapipe_trace_1.binarypb
|
||||
/storage/emulated/0/Download/mediapipe_trace_0.binarypb
|
||||
/storage/emulated/0/Download/mediapipe_trace_1.binarypb
|
||||
```
|
||||
|
||||
After every 5 sec, writing shifts to a successive trace log file, such that
|
||||
|
@ -139,10 +140,10 @@ we record ten intervals of half a second each. This can be overridden by adding
|
|||
trace files have been written to the device using adb shell.
|
||||
|
||||
```bash
|
||||
adb shell "ls -la /sdcard/"
|
||||
adb shell "ls -la /storage/emulated/0/Download"
|
||||
```
|
||||
|
||||
On android, MediaPipe selects the external storage directory `/sdcard` for
|
||||
On Android, MediaPipe selects the external storage (e.g., `/storage/emulated/0/`) for
|
||||
trace logs. This directory can be overridden using the setting
|
||||
`trace_log_path`, like:
|
||||
|
||||
|
@ -150,7 +151,7 @@ we record ten intervals of half a second each. This can be overridden by adding
|
|||
profiler_config {
|
||||
trace_enabled: true
|
||||
enable_profiler: true
|
||||
trace_log_path: "/sdcard/profiles/"
|
||||
trace_log_path: "/sdcard/Download/profiles/"
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -161,7 +162,7 @@ we record ten intervals of half a second each. This can be overridden by adding
|
|||
|
||||
```bash
|
||||
# from your terminal
|
||||
adb pull /sdcard/mediapipe_trace_0.binarypb
|
||||
adb pull /storage/emulated/0/Download/mediapipe_trace_0.binarypb
|
||||
# if successful you should see something like
|
||||
# /sdcard/mediapipe_trace_0.binarypb: 1 file pulled. 0.1 MB/s (6766 bytes in 0.045s)
|
||||
```
|
||||
|
|
|
@ -128,7 +128,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -147,7 +147,7 @@ cc_library(
|
|||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp/mfcc",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -168,7 +168,7 @@ cc_library(
|
|||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:resampler",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -208,7 +208,7 @@ cc_library(
|
|||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@com_google_audio_tools//audio/dsp/spectrogram",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -228,7 +228,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -242,9 +242,9 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -261,7 +261,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -276,7 +276,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -296,7 +296,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:number_util",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -314,7 +314,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -333,7 +333,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -352,6 +352,6 @@ cc_test(
|
|||
"//mediapipe/framework/tool:validate_type",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -12,10 +12,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
|
|
@ -414,7 +414,7 @@ cc_library(
|
|||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -430,7 +430,7 @@ cc_library(
|
|||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -450,6 +450,20 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nonzero_calculator",
|
||||
srcs = ["nonzero_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mux_calculator_test",
|
||||
srcs = ["mux_calculator_test.cc"],
|
||||
|
@ -776,7 +790,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -793,7 +807,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1024,7 +1038,7 @@ cc_library(
|
|||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -57,7 +57,7 @@ namespace mediapipe {
|
|||
//
|
||||
// The "ALLOW" stream indicates the transition between accepting frames and
|
||||
// dropping frames. "ALLOW = true" indicates the start of accepting frames
|
||||
// including the current timestamp, and "ALLOW = true" indicates the start of
|
||||
// including the current timestamp, and "ALLOW = false" indicates the start of
|
||||
// dropping frames including the current timestamp.
|
||||
//
|
||||
// FlowLimiterCalculator provides limited support for multiple input streams.
|
||||
|
|
mediapipe/calculators/core/nonzero_calculator.cc (new file, 42 lines)
|
@ -0,0 +1,42 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {
namespace api2 {

// A Calculator that returns 0 if INPUT is 0, and 1 otherwise.
class NonZeroCalculator : public Node {
 public:
  static constexpr Input<int>::SideFallback kIn{"INPUT"};
  static constexpr Output<int> kOut{"OUTPUT"};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  absl::Status Process(CalculatorContext* cc) final {
    if (!kIn(cc).IsEmpty()) {
      auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0);
      kOut(cc).Send(std::move(output));
    }
    return absl::OkStatus();
  }
};

MEDIAPIPE_REGISTER_NODE(NonZeroCalculator);

} // namespace api2
} // namespace mediapipe
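A minimal usage sketch, not part of this change (stream names are illustrative): the node can be wired into a graph config and fed, for example, a per-frame detection count on `INPUT`.

```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: a tiny graph that maps an integer "count" stream to 0/1 via
// NonZeroCalculator.
mediapipe::CalculatorGraphConfig MakeNonZeroGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "count"
    output_stream: "has_detections"
    node {
      calculator: "NonZeroCalculator"
      input_stream: "INPUT:count"
      output_stream: "OUTPUT:has_detections"
    }
  )pb");
}
```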
|
|
@ -87,7 +87,6 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
flush_last_packet_ = resampler_options.flush_last_packet();
|
||||
jitter_ = resampler_options.jitter();
|
||||
jitter_with_reflection_ = resampler_options.jitter_with_reflection();
|
||||
|
||||
input_data_id_ = cc->Inputs().GetId("DATA", 0);
|
||||
if (!input_data_id_.IsValid()) {
|
||||
|
@ -98,11 +97,7 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
|||
output_data_id_ = cc->Outputs().GetId("", 0);
|
||||
}
|
||||
|
||||
period_count_ = 0;
|
||||
frame_rate_ = resampler_options.frame_rate();
|
||||
base_timestamp_ = resampler_options.has_base_timestamp()
|
||||
? Timestamp(resampler_options.base_timestamp())
|
||||
: Timestamp::Unset();
|
||||
start_time_ = resampler_options.has_start_time()
|
||||
? Timestamp(resampler_options.start_time())
|
||||
: Timestamp::Min();
|
||||
|
@ -141,30 +136,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
if (jitter_ != 0.0) {
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE) {
|
||||
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
|
||||
"the actual value.";
|
||||
}
|
||||
if (flush_last_packet_) {
|
||||
flush_last_packet_ = false;
|
||||
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
|
||||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::Status(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"SecureRandom is not available. With \"jitter\" specified, "
|
||||
"PacketResamplerCalculator processing cannot proceed.");
|
||||
}
|
||||
packet_reservoir_random_ = CreateSecureRandom(seed);
|
||||
}
|
||||
packet_reservoir_ =
|
||||
std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
|
||||
return absl::OkStatus();
|
||||
strategy_ = GetSamplingStrategy(resampler_options);
|
||||
|
||||
return strategy_->Open(cc);
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
||||
|
@ -177,171 +151,13 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
if (jitter_ != 0.0 && random_ != nullptr) {
|
||||
// Packet reservoir is used to make sure there's an output for every period,
|
||||
// e.g. partial period at the end of the stream.
|
||||
if (packet_reservoir_->IsEnabled() &&
|
||||
(first_timestamp_ == Timestamp::Unset() ||
|
||||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
|
||||
auto curr_packet = cc->Inputs().Get(input_data_id_).Value();
|
||||
packet_reservoir_->AddSample(curr_packet);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(ProcessWithJitter(cc));
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(ProcessWithoutJitter(cc));
|
||||
|
||||
if (absl::Status status = strategy_->Process(cc); !status.ok()) {
|
||||
return status; // Avoid MP_RETURN_IF_ERROR macro for external release.
|
||||
}
|
||||
|
||||
last_packet_ = cc->Inputs().Get(input_data_id_).Value();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() {
|
||||
next_output_timestamp_min_ = first_timestamp_;
|
||||
if (jitter_with_reflection_) {
|
||||
next_output_timestamp_ =
|
||||
first_timestamp_ + random_->UnbiasedUniform64(frame_time_usec_);
|
||||
return;
|
||||
}
|
||||
next_output_timestamp_ =
|
||||
first_timestamp_ + frame_time_usec_ * random_->RandFloat();
|
||||
}
|
||||
|
||||
void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() {
|
||||
packet_reservoir_->Clear();
|
||||
if (jitter_with_reflection_) {
|
||||
next_output_timestamp_min_ += frame_time_usec_;
|
||||
Timestamp next_output_timestamp_max_ =
|
||||
next_output_timestamp_min_ + frame_time_usec_;
|
||||
|
||||
next_output_timestamp_ += frame_time_usec_ +
|
||||
random_->UnbiasedUniform64(2 * jitter_usec_ + 1) -
|
||||
jitter_usec_;
|
||||
next_output_timestamp_ = Timestamp(ReflectBetween(
|
||||
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
|
||||
next_output_timestamp_max_.Value()));
|
||||
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
|
||||
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
|
||||
return;
|
||||
}
|
||||
packet_reservoir_->Disable();
|
||||
next_output_timestamp_ +=
|
||||
frame_time_usec_ *
|
||||
((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::ProcessWithJitter(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
RET_CHECK_NE(jitter_, 0.0);
|
||||
|
||||
if (first_timestamp_ == Timestamp::Unset()) {
|
||||
first_timestamp_ = cc->InputTimestamp();
|
||||
InitializeNextOutputTimestampWithJitter();
|
||||
if (first_timestamp_ == next_output_timestamp_) {
|
||||
OutputWithinLimits(
|
||||
cc,
|
||||
cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (frame_time_usec_ <
|
||||
(cc->InputTimestamp() - last_packet_.Timestamp()).Value()) {
|
||||
LOG_FIRST_N(WARNING, 2)
|
||||
<< "Adding jitter is not very useful when upsampling.";
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const int64 last_diff =
|
||||
(next_output_timestamp_ - last_packet_.Timestamp()).Value();
|
||||
RET_CHECK_GT(last_diff, 0);
|
||||
const int64 curr_diff =
|
||||
(next_output_timestamp_ - cc->InputTimestamp()).Value();
|
||||
if (curr_diff > 0) {
|
||||
break;
|
||||
}
|
||||
OutputWithinLimits(cc, (std::abs(curr_diff) > last_diff
|
||||
? last_packet_
|
||||
: cc->Inputs().Get(input_data_id_).Value())
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
// From now on every time a packet is emitted the timestamp of the next
|
||||
// packet becomes known; that timestamp is stored in next_output_timestamp_.
|
||||
// The only exception to this rule is the packet emitted from Close() which
|
||||
// can only happen when jitter_with_reflection is enabled but in this case
|
||||
// next_output_timestamp_min_ is a non-decreasing lower bound of any
|
||||
// subsequent packet.
|
||||
const Timestamp timestamp_bound = jitter_with_reflection_
|
||||
? next_output_timestamp_min_
|
||||
: next_output_timestamp_;
|
||||
cc->Outputs().Get(output_data_id_).SetNextTimestampBound(timestamp_bound);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::ProcessWithoutJitter(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
RET_CHECK_EQ(jitter_, 0.0);
|
||||
|
||||
if (first_timestamp_ == Timestamp::Unset()) {
|
||||
// This is the first packet, initialize the first_timestamp_.
|
||||
if (base_timestamp_ == Timestamp::Unset()) {
|
||||
// Initialize first_timestamp_ with exactly the first packet timestamp.
|
||||
first_timestamp_ = cc->InputTimestamp();
|
||||
} else {
|
||||
// Initialize first_timestamp_ with the first packet timestamp
|
||||
// aligned to the base_timestamp_.
|
||||
int64 first_index = MathUtil::SafeRound<int64, double>(
|
||||
(cc->InputTimestamp() - base_timestamp_).Seconds() * frame_rate_);
|
||||
first_timestamp_ =
|
||||
base_timestamp_ + TimestampDiffFromSeconds(first_index / frame_rate_);
|
||||
}
|
||||
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Outputs()
|
||||
.Tag("VIDEO_HEADER")
|
||||
.Add(new VideoHeader(video_header_), Timestamp::PreStream());
|
||||
}
|
||||
}
|
||||
const Timestamp received_timestamp = cc->InputTimestamp();
|
||||
const int64 received_timestamp_idx =
|
||||
TimestampToPeriodIndex(received_timestamp);
|
||||
// Only consider the received packet if it belongs to the current period
|
||||
// (== period_count_) or to a newer one (> period_count_).
|
||||
if (received_timestamp_idx >= period_count_) {
|
||||
// Fill the empty periods until we are in the same index as the received
|
||||
// packet.
|
||||
while (received_timestamp_idx > period_count_) {
|
||||
OutputWithinLimits(
|
||||
cc, last_packet_.At(PeriodIndexToTimestamp(period_count_)));
|
||||
++period_count_;
|
||||
}
|
||||
// Now, if the received packet has a timestamp larger than the middle of
|
||||
// the current period, we can send a packet without waiting. We send the
|
||||
// one closer to the middle.
|
||||
Timestamp target_timestamp = PeriodIndexToTimestamp(period_count_);
|
||||
if (received_timestamp >= target_timestamp) {
|
||||
bool have_last_packet = (last_packet_.Timestamp() != Timestamp::Unset());
|
||||
bool send_current =
|
||||
!have_last_packet || (received_timestamp - target_timestamp <=
|
||||
target_timestamp - last_packet_.Timestamp());
|
||||
if (send_current) {
|
||||
OutputWithinLimits(
|
||||
cc, cc->Inputs().Get(input_data_id_).Value().At(target_timestamp));
|
||||
} else {
|
||||
OutputWithinLimits(cc, last_packet_.At(target_timestamp));
|
||||
}
|
||||
++period_count_;
|
||||
}
|
||||
// TODO: Add a mechanism to the framework to allow these packets
|
||||
// to be output earlier (without waiting for a much later packet to
|
||||
// arrive)
|
||||
|
||||
// Update the bound for the next packet.
|
||||
cc->Outputs()
|
||||
.Get(output_data_id_)
|
||||
.SetNextTimestampBound(PeriodIndexToTimestamp(period_count_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -349,17 +165,34 @@ absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) {
|
|||
if (!cc->GraphStatus().ok()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// Emit the last packet received if we have at least one packet, but
|
||||
// haven't sent anything for its period.
|
||||
if (first_timestamp_ != Timestamp::Unset() && flush_last_packet_ &&
|
||||
TimestampToPeriodIndex(last_packet_.Timestamp()) == period_count_) {
|
||||
OutputWithinLimits(cc,
|
||||
last_packet_.At(PeriodIndexToTimestamp(period_count_)));
|
||||
|
||||
return strategy_->Close(cc);
|
||||
}
|
||||
|
||||
std::unique_ptr<PacketResamplerStrategy>
|
||||
PacketResamplerCalculator::GetSamplingStrategy(
|
||||
const PacketResamplerCalculatorOptions& options) {
|
||||
if (options.reproducible_sampling()) {
|
||||
if (!options.jitter_with_reflection()) {
|
||||
LOG(WARNING)
|
||||
<< "reproducible_sampling enabled w/ jitter_with_reflection "
|
||||
"disabled. "
|
||||
<< "reproducible_sampling always uses jitter with reflection, "
|
||||
<< "Ignoring jitter_with_reflection setting.";
|
||||
}
|
||||
return absl::make_unique<ReproducibleJitterWithReflectionStrategy>(this);
|
||||
}
|
||||
if (!packet_reservoir_->IsEmpty()) {
|
||||
OutputWithinLimits(cc, packet_reservoir_->GetSample());
|
||||
|
||||
if (options.jitter() == 0) {
|
||||
return absl::make_unique<NoJitterStrategy>(this);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
||||
if (options.jitter_with_reflection()) {
|
||||
return absl::make_unique<LegacyJitterWithReflectionStrategy>(this);
|
||||
}
|
||||
|
||||
// With jitter and no reflection.
|
||||
return absl::make_unique<JitterWithoutReflectionStrategy>(this);
|
||||
}
|
||||
|
||||
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const {
|
||||
|
@ -385,4 +218,479 @@ void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc,
|
|||
}
|
||||
}
|
||||
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE) {
|
||||
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
|
||||
"the actual value.";
|
||||
}
|
||||
|
||||
if (calculator_->flush_last_packet_) {
|
||||
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
|
||||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"SecureRandom is not available. With \"jitter\" specified, "
|
||||
"PacketResamplerCalculator processing cannot proceed.");
|
||||
}
|
||||
|
||||
packet_reservoir_random_ = CreateSecureRandom(seed);
|
||||
packet_reservoir_ =
|
||||
std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
|
||||
if (!packet_reservoir_->IsEmpty()) {
|
||||
LOG(INFO) << "Emitting pack from reservoir.";
|
||||
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Process(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
|
||||
if (packet_reservoir_->IsEnabled() &&
|
||||
(first_timestamp_ == Timestamp::Unset() ||
|
||||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
|
||||
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
|
||||
packet_reservoir_->AddSample(curr_packet);
|
||||
}
|
||||
|
||||
if (first_timestamp_ == Timestamp::Unset()) {
|
||||
first_timestamp_ = cc->InputTimestamp();
|
||||
InitializeNextOutputTimestampWithJitter();
|
||||
if (first_timestamp_ == next_output_timestamp_) {
|
||||
calculator_->OutputWithinLimits(cc, cc->Inputs()
|
||||
.Get(calculator_->input_data_id_)
|
||||
.Value()
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (calculator_->frame_time_usec_ <
|
||||
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
|
||||
LOG_FIRST_N(WARNING, 2)
|
||||
<< "Adding jitter is not very useful when upsampling.";
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const int64 last_diff =
|
||||
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
|
||||
.Value();
|
||||
RET_CHECK_GT(last_diff, 0);
|
||||
const int64 curr_diff =
|
||||
(next_output_timestamp_ - cc->InputTimestamp()).Value();
|
||||
if (curr_diff > 0) {
|
||||
break;
|
||||
}
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, (std::abs(curr_diff) > last_diff
|
||||
? calculator_->last_packet_
|
||||
: cc->Inputs().Get(calculator_->input_data_id_).Value())
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
// From now on every time a packet is emitted the timestamp of the next
|
||||
// packet becomes known; that timestamp is stored in next_output_timestamp_.
|
||||
// The only exception to this rule is the packet emitted from Close() which
|
||||
// can only happen when jitter_with_reflection is enabled but in this case
|
||||
// next_output_timestamp_min_ is a non-decreasing lower bound of any
|
||||
// subsequent packet.
|
||||
const Timestamp timestamp_bound = next_output_timestamp_min_;
|
||||
cc->Outputs()
|
||||
.Get(calculator_->output_data_id_)
|
||||
.SetNextTimestampBound(timestamp_bound);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void LegacyJitterWithReflectionStrategy::
|
||||
InitializeNextOutputTimestampWithJitter() {
|
||||
next_output_timestamp_min_ = first_timestamp_;
|
||||
next_output_timestamp_ =
|
||||
first_timestamp_ +
|
||||
random_->UnbiasedUniform64(calculator_->frame_time_usec_);
|
||||
}
|
||||
|
||||
void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
  packet_reservoir_->Clear();
  next_output_timestamp_min_ += calculator_->frame_time_usec_;
  Timestamp next_output_timestamp_max_ =
      next_output_timestamp_min_ + calculator_->frame_time_usec_;

  next_output_timestamp_ +=
      calculator_->frame_time_usec_ +
      random_->UnbiasedUniform64(2 * calculator_->jitter_usec_ + 1) -
      calculator_->jitter_usec_;
  next_output_timestamp_ = Timestamp(ReflectBetween(
      next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
      next_output_timestamp_max_.Value()));
  CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
  CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
}

absl::Status ReproducibleJitterWithReflectionStrategy::Open(
    CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  if (resampler_options.output_header() !=
      PacketResamplerCalculatorOptions::NONE) {
    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
                    "the actual value.";
  }

  if (calculator_->flush_last_packet_) {
    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
                    "ignored, because we are adding jitter.";
  }

  const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
  random_ = CreateSecureRandom(seed);
  if (random_ == nullptr) {
    return absl::InvalidArgumentError(
        "SecureRandom is not available. With \"jitter\" specified, "
        "PacketResamplerCalculator processing cannot proceed.");
  }

  return absl::OkStatus();
}
absl::Status ReproducibleJitterWithReflectionStrategy::Close(
    CalculatorContext* cc) {
  // If last packet is non-empty and a packet hasn't been emitted for this
  // period, emit the last packet.
  if (!calculator_->last_packet_.IsEmpty() && !packet_emitted_this_period_) {
    calculator_->OutputWithinLimits(
        cc, calculator_->last_packet_.At(next_output_timestamp_));
  }
  return absl::OkStatus();
}
absl::Status ReproducibleJitterWithReflectionStrategy::Process(
    CalculatorContext* cc) {
  RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());

  Packet current_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();

  if (calculator_->last_packet_.IsEmpty()) {
    // last_packet is empty, this is the first packet of the stream.

    InitializeNextOutputTimestamp(current_packet.Timestamp());

    // If next_output_timestamp_ happens to fall before current_packet, emit
    // current packet.  Only a single packet can be emitted at the beginning
    // of the stream.
    if (next_output_timestamp_ < current_packet.Timestamp()) {
      calculator_->OutputWithinLimits(
          cc, current_packet.At(next_output_timestamp_));
      packet_emitted_this_period_ = true;
    }

    return absl::OkStatus();
  }

  // Last packet is set, so we are mid-stream.
  if (calculator_->frame_time_usec_ <
      (current_packet.Timestamp() - calculator_->last_packet_.Timestamp())
          .Value()) {
    // Note, if the stream is upsampling, this could lead to the same packet
    // being emitted twice.  Upsampling and jitter doesn't make much sense
    // but does technically work.
    LOG_FIRST_N(WARNING, 2)
        << "Adding jitter is not very useful when upsampling.";
  }

  // Since we may be upsampling, we need to iteratively advance the
  // next_output_timestamp_ one period at a time until it reaches the period
  // current_packet is in.  During this process, last_packet and/or
  // current_packet may be repeatedly emitted.

  UpdateNextOutputTimestamp(current_packet.Timestamp());

  while (!packet_emitted_this_period_ &&
         next_output_timestamp_ <= current_packet.Timestamp()) {
    // last_packet < next_output_timestamp_ <= current_packet,
    // so emit the closest packet.
    Packet packet_to_emit =
        current_packet.Timestamp() - next_output_timestamp_ <
                next_output_timestamp_ - calculator_->last_packet_.Timestamp()
            ? current_packet
            : calculator_->last_packet_;
    calculator_->OutputWithinLimits(cc,
                                    packet_to_emit.At(next_output_timestamp_));

    packet_emitted_this_period_ = true;

    // If we are upsampling, packet_emitted_this_period_ can be reset by
    // the following UpdateNext and the loop will iterate.
    UpdateNextOutputTimestamp(current_packet.Timestamp());
  }

  // Set the bounds on the output stream.  Note, if we emitted a packet
  // above, it will already be set at next_output_timestamp_ + 1, in which
  // case we have to skip setting it.
  if (cc->Outputs().Get(calculator_->output_data_id_).NextTimestampBound() <
      next_output_timestamp_) {
    cc->Outputs()
        .Get(calculator_->output_data_id_)
        .SetNextTimestampBound(next_output_timestamp_);
  }
  return absl::OkStatus();
}

void ReproducibleJitterWithReflectionStrategy::InitializeNextOutputTimestamp(
    Timestamp current_timestamp) {
  if (next_output_timestamp_min_ != Timestamp::Unset()) {
    return;
  }

  next_output_timestamp_min_ = Timestamp(0);
  next_output_timestamp_ =
      Timestamp(GetNextRandom(calculator_->frame_time_usec_));

  // While the current timestamp is ahead of the max (i.e. min + frame_time),
  // fast-forward.
  while (current_timestamp >=
         next_output_timestamp_min_ + calculator_->frame_time_usec_) {
    packet_emitted_this_period_ = true;  // Force update...
    UpdateNextOutputTimestamp(current_timestamp);
  }
}

void ReproducibleJitterWithReflectionStrategy::UpdateNextOutputTimestamp(
    Timestamp current_timestamp) {
  if (packet_emitted_this_period_ &&
      current_timestamp >=
          next_output_timestamp_min_ + calculator_->frame_time_usec_) {
    next_output_timestamp_min_ += calculator_->frame_time_usec_;
    Timestamp next_output_timestamp_max_ =
        next_output_timestamp_min_ + calculator_->frame_time_usec_;

    next_output_timestamp_ += calculator_->frame_time_usec_ +
                              GetNextRandom(2 * calculator_->jitter_usec_ + 1) -
                              calculator_->jitter_usec_;
    next_output_timestamp_ = Timestamp(ReflectBetween(
        next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
        next_output_timestamp_max_.Value()));

    packet_emitted_this_period_ = false;
  }
}

absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  if (resampler_options.output_header() !=
      PacketResamplerCalculatorOptions::NONE) {
    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
                    "the actual value.";
  }

  if (calculator_->flush_last_packet_) {
    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
                    "ignored, because we are adding jitter.";
  }

  const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
  random_ = CreateSecureRandom(seed);
  if (random_ == nullptr) {
    return absl::InvalidArgumentError(
        "SecureRandom is not available. With \"jitter\" specified, "
        "PacketResamplerCalculator processing cannot proceed.");
  }

  packet_reservoir_random_ = CreateSecureRandom(seed);
  packet_reservoir_ =
      absl::make_unique<PacketReservoir>(packet_reservoir_random_.get());

  return absl::OkStatus();
}
absl::Status JitterWithoutReflectionStrategy::Close(CalculatorContext* cc) {
  if (!packet_reservoir_->IsEmpty()) {
    calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
  }
  return absl::OkStatus();
}
absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
  RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());

  // Packet reservoir is used to make sure there's an output for every period,
  // e.g. partial period at the end of the stream.
  if (packet_reservoir_->IsEnabled() &&
      (calculator_->first_timestamp_ == Timestamp::Unset() ||
       (cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
    auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
    packet_reservoir_->AddSample(curr_packet);
  }

  if (calculator_->first_timestamp_ == Timestamp::Unset()) {
    calculator_->first_timestamp_ = cc->InputTimestamp();
    InitializeNextOutputTimestamp();
    if (calculator_->first_timestamp_ == next_output_timestamp_) {
      calculator_->OutputWithinLimits(cc, cc->Inputs()
                                              .Get(calculator_->input_data_id_)
                                              .Value()
                                              .At(next_output_timestamp_));
      UpdateNextOutputTimestamp();
    }
    return absl::OkStatus();
  }

  if (calculator_->frame_time_usec_ <
      (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
    LOG_FIRST_N(WARNING, 2)
        << "Adding jitter is not very useful when upsampling.";
  }

  while (true) {
    const int64 last_diff =
        (next_output_timestamp_ - calculator_->last_packet_.Timestamp())
            .Value();
    RET_CHECK_GT(last_diff, 0);
    const int64 curr_diff =
        (next_output_timestamp_ - cc->InputTimestamp()).Value();
    if (curr_diff > 0) {
      break;
    }
    calculator_->OutputWithinLimits(
        cc, (std::abs(curr_diff) > last_diff
                 ? calculator_->last_packet_
                 : cc->Inputs().Get(calculator_->input_data_id_).Value())
                .At(next_output_timestamp_));
    UpdateNextOutputTimestamp();
    cc->Outputs()
        .Get(calculator_->output_data_id_)
        .SetNextTimestampBound(next_output_timestamp_);
  }
  return absl::OkStatus();
}

void JitterWithoutReflectionStrategy::InitializeNextOutputTimestamp() {
  next_output_timestamp_min_ = calculator_->first_timestamp_;
  next_output_timestamp_ = calculator_->first_timestamp_ +
                           calculator_->frame_time_usec_ * random_->RandFloat();
}

void JitterWithoutReflectionStrategy::UpdateNextOutputTimestamp() {
  packet_reservoir_->Clear();
  packet_reservoir_->Disable();
  next_output_timestamp_ += calculator_->frame_time_usec_ *
                            ((1.0 - calculator_->jitter_) +
                             2.0 * calculator_->jitter_ * random_->RandFloat());
}

absl::Status NoJitterStrategy::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");
  base_timestamp_ = resampler_options.has_base_timestamp()
                        ? Timestamp(resampler_options.base_timestamp())
                        : Timestamp::Unset();

  period_count_ = 0;

  return absl::OkStatus();
}
absl::Status NoJitterStrategy::Close(CalculatorContext* cc) {
  // Emit the last packet received if we have at least one packet, but
  // haven't sent anything for its period.
  if (calculator_->first_timestamp_ != Timestamp::Unset() &&
      calculator_->flush_last_packet_ &&
      calculator_->TimestampToPeriodIndex(
          calculator_->last_packet_.Timestamp()) == period_count_) {
    calculator_->OutputWithinLimits(
        cc, calculator_->last_packet_.At(
                calculator_->PeriodIndexToTimestamp(period_count_)));
  }
  return absl::OkStatus();
}
absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
  RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());

  if (calculator_->first_timestamp_ == Timestamp::Unset()) {
    // This is the first packet, initialize the first_timestamp_.
    if (base_timestamp_ == Timestamp::Unset()) {
      // Initialize first_timestamp_ with exactly the first packet timestamp.
      calculator_->first_timestamp_ = cc->InputTimestamp();
    } else {
      // Initialize first_timestamp_ with the first packet timestamp
      // aligned to the base_timestamp_.
      int64 first_index = MathUtil::SafeRound<int64, double>(
          (cc->InputTimestamp() - base_timestamp_).Seconds() *
          calculator_->frame_rate_);
      calculator_->first_timestamp_ =
          base_timestamp_ +
          TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
    }
    if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
      cc->Outputs()
          .Tag("VIDEO_HEADER")
          .Add(new VideoHeader(calculator_->video_header_),
               Timestamp::PreStream());
    }
  }
  const Timestamp received_timestamp = cc->InputTimestamp();
  const int64 received_timestamp_idx =
      calculator_->TimestampToPeriodIndex(received_timestamp);
  // Only consider the received packet if it belongs to the current period
  // (== period_count_) or to a newer one (> period_count_).
  if (received_timestamp_idx >= period_count_) {
    // Fill the empty periods until we are in the same index as the received
    // packet.
    while (received_timestamp_idx > period_count_) {
      calculator_->OutputWithinLimits(
          cc, calculator_->last_packet_.At(
                  calculator_->PeriodIndexToTimestamp(period_count_)));
      ++period_count_;
    }
    // Now, if the received packet has a timestamp larger than the middle of
    // the current period, we can send a packet without waiting. We send the
    // one closer to the middle.
    Timestamp target_timestamp =
        calculator_->PeriodIndexToTimestamp(period_count_);
    if (received_timestamp >= target_timestamp) {
      bool have_last_packet =
          (calculator_->last_packet_.Timestamp() != Timestamp::Unset());
      bool send_current =
          !have_last_packet ||
          (received_timestamp - target_timestamp <=
           target_timestamp - calculator_->last_packet_.Timestamp());
      if (send_current) {
        calculator_->OutputWithinLimits(cc,
                                        cc->Inputs()
                                            .Get(calculator_->input_data_id_)
                                            .Value()
                                            .At(target_timestamp));
      } else {
        calculator_->OutputWithinLimits(
            cc, calculator_->last_packet_.At(target_timestamp));
      }
      ++period_count_;
    }
    // TODO: Add a mechanism to the framework to allow these packets
    // to be output earlier (without waiting for a much later packet to
    // arrive)

    // Update the bound for the next packet.
    cc->Outputs()
        .Get(calculator_->output_data_id_)
        .SetNextTimestampBound(
            calculator_->PeriodIndexToTimestamp(period_count_));
  }
  return absl::OkStatus();
}

}  // namespace mediapipe
@ -55,7 +55,7 @@ class PacketReservoir {
//   correspond to timestamp t.
// - The next packet is chosen randomly (uniform distribution) among frames
//   that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate].
// - if jitter_with_reflection_ is true, the timestamp will be reflected
// - if jitter_with_reflection is true, the timestamp will be reflected
//   against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate)
//   so that its marginal distribution is uniform within this interval.
//   In the formula, t_0 is the timestamp of the first sampled
@ -66,6 +66,17 @@ class PacketReservoir {
//   the resampling. For Cloud ML Video Intelligence API, the hash of the
//   input video should serve this purpose. For YouTube, either video ID or
//   content hex ID of the input video should do.
// - If reproducible_sampling is true, care is taken to allow reproducible
//   "mid-stream" sampling. The calculator can be executed on a stream that
//   doesn't start at the first period. For instance, if the calculator
//   is run on a 10 second stream it will produce the same set of samples
//   as two runs of the calculator, the first with 3 seconds of input starting
//   at time 0 and the second with 7 seconds of input starting at time +3s.
// - In order to guarantee the exact same samples, 1) the inputs must be
//   aligned with the sampling period. For instance, if the sampling rate
//   is 2 frames per second, streams should be aligned on 0.5 second
//   boundaries, and 2) the stream must include at least one extra packet
//   before and after the second aligned sampling period.
//
// If jitter_ is not specified:
// - The first packet defines the first_timestamp of the output stream,
@ -105,19 +116,6 @@ class PacketResamplerCalculator : public CalculatorBase {
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Calculates the first sampled timestamp that incorporates a jittering
  // offset.
  void InitializeNextOutputTimestampWithJitter();
  // Calculates the next sampled timestamp that incorporates a jittering offset.
  void UpdateNextOutputTimestampWithJitter();

  // Logic for Process() when jitter_ != 0.0.
  absl::Status ProcessWithJitter(CalculatorContext* cc);

  // Logic for Process() when jitter_ == 0.0.
  absl::Status ProcessWithoutJitter(CalculatorContext* cc);

  // Given the current count of periods that have passed, this returns
  // the next valid timestamp of the middle point of the next period:
  // if count is 0, it returns the first_timestamp_.
@ -141,6 +139,16 @@ class PacketResamplerCalculator : public CalculatorBase {
  // Outputs a packet if it is in range (start_time_, end_time_).
  void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const;

 protected:
  // Returns the sampling strategy to use.
  //
  // Virtual to allow injection of testing strategies.
  virtual std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
      const mediapipe::PacketResamplerCalculatorOptions& options);

 private:
  std::unique_ptr<class PacketResamplerStrategy> strategy_;

  // The timestamp of the first packet received.
  Timestamp first_timestamp_;

@ -150,14 +158,6 @@ class PacketResamplerCalculator : public CalculatorBase {
  // Inverse of frame_rate_.
  int64 frame_time_usec_;

  // Number of periods that have passed (= #packets sent to the output).
  //
  // Can only be used if jitter_ equals zero.
  int64 period_count_;

  // The last packet that was received.
  Packet last_packet_;

  VideoHeader video_header_;
  // The "DATA" input stream.
  CollectionItemId input_data_id_;
@ -165,23 +165,15 @@ class PacketResamplerCalculator : public CalculatorBase {
  CollectionItemId output_data_id_;

  // Indicator whether to flush last packet even if its timestamp is greater
  // than the final stream timestamp. Set to false when jitter_ is non-zero.
  // than the final stream timestamp.
  bool flush_last_packet_;

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;
  double jitter_ = 0.0;
  bool jitter_with_reflection_;
  int64 jitter_usec_;
  Timestamp next_output_timestamp_;
  // If jittering_with_reflection_ is true, next_output_timestamp_ will be
  // kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_;

  // If specified, output timestamps are aligned with base_timestamp.
  // Otherwise, they are aligned with the first input timestamp.
  Timestamp base_timestamp_;
  int64 jitter_usec_;

  // The last packet that was received.
  Packet last_packet_;

  // If specified, only outputs at/after start_time are included.
  Timestamp start_time_;
@ -191,15 +183,210 @@ class PacketResamplerCalculator : public CalculatorBase {

  // If set, the output timestamps nearest to start_time and end_time
  // are included in the output, even if the nearest timestamp is not
  // between start_time and end_time.W
  // between start_time and end_time.
  bool round_limits_;

  // Allow strategies access to all internal calculator state.
  //
  // The calculator and strategies are intimately tied together so this should
  // not break encapsulation.
  friend class LegacyJitterWithReflectionStrategy;
  friend class ReproducibleJitterWithReflectionStrategy;
  friend class JitterWithoutReflectionStrategy;
  friend class NoJitterStrategy;
};

// Abstract class encapsulating a sampling strategy.
//
// These are used solely by PacketResamplerCalculator, but are exposed here
// to facilitate tests.
class PacketResamplerStrategy {
 public:
  PacketResamplerStrategy(PacketResamplerCalculator* calculator)
      : calculator_(calculator) {}
  virtual ~PacketResamplerStrategy() = default;

  // Delegate for CalculatorBase::Open. See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Open(CalculatorContext* cc) = 0;
  // Delegate for CalculatorBase::Close. See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Close(CalculatorContext* cc) = 0;
  // Delegate for CalculatorBase::Process. See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Process(CalculatorContext* cc) = 0;

 protected:
  // Calculator running this strategy.
  PacketResamplerCalculator* calculator_;
};

// Strategy that applies jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both jitter and reflection are
// enabled.
//
// This applies the legacy jitter with reflection which doesn't allow
// for reproducibility of sampling when starting mid-stream. This is maintained
// for backward compatibility.
class LegacyJitterWithReflectionStrategy : public PacketResamplerStrategy {
 public:
  LegacyJitterWithReflectionStrategy(PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  void InitializeNextOutputTimestampWithJitter();
  void UpdateNextOutputTimestampWithJitter();

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // The timestamp of the first packet received.
  Timestamp first_timestamp_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Packet reservoir used for sampling a random packet out of a partial
  // period when jitter is enabled.
  std::unique_ptr<PacketReservoir> packet_reservoir_;

  // Random number generator used in packet_reservoir_.
  std::unique_ptr<RandomBase> packet_reservoir_random_;
};

// Strategy that applies reproducible jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both jitter and reflection are
// enabled.
class ReproducibleJitterWithReflectionStrategy
    : public PacketResamplerStrategy {
 public:
  ReproducibleJitterWithReflectionStrategy(
      PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 protected:
  // Returns the next random value in the range [0, n).
  //
  // Exposed as a virtual function for testing jitter with reflection.
  // This is the only way random_ is accessed.
  virtual uint64 GetNextRandom(uint64 n) {
    return random_->UnbiasedUniform64(n);
  }

 private:
  // Initializes jitter with reflection.
  //
  // This will fast-forward to the period containing current_timestamp.
  // next_output_timestamp_ is guaranteed to be in current_timestamp's period
  // and packet_emitted_this_period_ will be set to false.
  void InitializeNextOutputTimestamp(Timestamp current_timestamp);

  // Potentially advances next_output_timestamp_ a single period.
  //
  // next_output_timestamp_ will only be advanced if packet_emitted_this_period_
  // is false. next_output_timestamp_ will never be advanced beyond
  // current_timestamp's period.
  //
  // However, next_output_timestamp_ could fall before current_timestamp's
  // period since only a single period can be advanced at a time.
  void UpdateNextOutputTimestamp(Timestamp current_timestamp);

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Indicates a packet was emitted for the current period (i.e. the period
  // next_output_timestamp_ falls in).
  bool packet_emitted_this_period_ = false;
};

// Strategy that applies jitter without reflection based sampling.
//
// Used by PacketResamplerCalculator when jitter is enabled and reflection is
// not enabled.
class JitterWithoutReflectionStrategy : public PacketResamplerStrategy {
 public:
  JitterWithoutReflectionStrategy(PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Calculates the first sampled timestamp that incorporates a jittering
  // offset.
  void InitializeNextOutputTimestamp();

  // Calculates the next sampled timestamp that incorporates a jittering offset.
  void UpdateNextOutputTimestamp();

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Packet reservoir used for sampling a random packet out of a partial period.
  std::unique_ptr<PacketReservoir> packet_reservoir_;

  // Random number generator used in packet_reservoir_.
  std::unique_ptr<RandomBase> packet_reservoir_random_;
};

// Strategy that applies sampling without any jitter.
//
// Used by PacketResamplerCalculator when jitter is not enabled.
class NoJitterStrategy : public PacketResamplerStrategy {
 public:
  NoJitterStrategy(PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Number of periods that have passed (= #packets sent to the output).
  int64 period_count_;

  // If specified, output timestamps are aligned with base_timestamp.
  // Otherwise, they are aligned with the first input timestamp.
  Timestamp base_timestamp_;
};

}  // namespace mediapipe
#endif  // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
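For orientation only: the calculator picks one of the strategy classes above through the virtual GetSamplingStrategy hook. The sketch below is a hypothetical illustration of that selection logic; the helper name PickStrategy and the exact ordering are assumptions for this note, not the actual MediaPipe implementation, but the mapping follows the option semantics documented in the options proto.

// Hypothetical sketch, not part of MediaPipe: shows how the calculator's
// options could map to the strategy classes declared above.
std::unique_ptr<PacketResamplerStrategy> PickStrategy(
    PacketResamplerCalculator* calculator,
    const mediapipe::PacketResamplerCalculatorOptions& options) {
  if (options.jitter() == 0.0) {
    // No jitter requested: plain periodic resampling.
    return absl::make_unique<NoJitterStrategy>(calculator);
  }
  if (options.reproducible_sampling()) {
    // Reproducible sampling always uses reflection.
    return absl::make_unique<ReproducibleJitterWithReflectionStrategy>(
        calculator);
  }
  if (options.jitter_with_reflection()) {
    // Deprecated legacy behavior, kept for backward compatibility.
    return absl::make_unique<LegacyJitterWithReflectionStrategy>(calculator);
  }
  return absl::make_unique<JitterWithoutReflectionStrategy>(calculator);
}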
@ -68,8 +68,23 @@ message PacketResamplerCalculatorOptions {
  // pseudo-random number generator does its job and the number of frames is
  // sufficiently large, the average frame rate will be close to this value.
  optional double jitter = 4;

  // Enables reflection when applying jitter.
  //
  // This option is ignored when reproducible_sampling is true, in which case
  // reflection will be used.
  //
  // New use cases should use reproducible_sampling = true, as
  // jitter_with_reflection is deprecated and will be removed at some point.
  optional bool jitter_with_reflection = 9 [default = false];

  // If set, enables reproducible sampling, allowing frames to be sampled
  // without regard to where the stream starts. See
  // packet_resampler_calculator.h for details.
  //
  // This enables reflection (ignoring the jitter_with_reflection setting).
  optional bool reproducible_sampling = 10 [default = false];

  // If specified, output timestamps are aligned with base_timestamp.
  // Otherwise, they are aligned with the first input timestamp.
  //
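For orientation only, a hedged sketch of how these options might be set on a PacketResamplerCalculator node in a graph config. The stream names, the SEED side packet wiring, and the frame_rate field shown here are illustrative assumptions based on the calculator code, not taken from this change.

node {
  calculator: "PacketResamplerCalculator"
  input_stream: "DATA:input_frames"
  output_stream: "DATA:resampled_frames"
  input_side_packet: "SEED:deterministic_seed"  # consumed when jitter is enabled
  options {
    [mediapipe.PacketResamplerCalculatorOptions.ext] {
      frame_rate: 2.0                # assumed target sampling rate field
      jitter: 0.5
      reproducible_sampling: true    # implies reflection
    }
  }
}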
@ -30,6 +30,7 @@

namespace mediapipe {

using ::testing::ElementsAre;
namespace {
// A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs
@ -96,6 +97,77 @@ class SimpleRunner : public CalculatorRunner {
  static int static_count_;
};

// Matcher for Packets with uint64 payload, comparing arg packet's
// timestamp and uint64 payload.
MATCHER_P2(PacketAtTimestamp, payload, timestamp,
           absl::StrCat(negation ? "isn't" : "is", " a packet with payload ",
                        payload, " @ time ", timestamp)) {
  if (timestamp != arg.Timestamp().Value()) {
    *result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
    return false;
  }
  int64 actual_payload = arg.template Get<int64>();
  if (actual_payload != payload) {
    *result_listener << "with incorrect payload = " << actual_payload;
    return false;
  }
  return true;
}

// JitterWithReflectionStrategy child class which injects a specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleJitterWithReflectionStrategyForTesting
    : public ReproducibleJitterWithReflectionStrategy {
 public:
  ReproducibleJitterWithReflectionStrategyForTesting(
      PacketResamplerCalculator* calculator)
      : ReproducibleJitterWithReflectionStrategy(calculator) {}

  // Statically accessed random sequence to use for jitter with reflection.
  //
  // An EXPECT will fail if sequence is less than the number requested during
  // processing.
  static std::vector<uint64> random_sequence;

 protected:
  virtual uint64 GetNextRandom(uint64 n) {
    EXPECT_LT(sequence_index_, random_sequence.size());
    return random_sequence[sequence_index_++] % n;
  }

 private:
  int32 sequence_index_ = 0;
};
std::vector<uint64>
    ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;

// PacketResamplerCalculator child class which injects a specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleResamplerCalculatorForTesting
    : public PacketResamplerCalculator {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    return PacketResamplerCalculator::GetContract(cc);
  }

 protected:
  std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
      const mediapipe::PacketResamplerCalculatorOptions& Options) {
    return absl::make_unique<
        ReproducibleJitterWithReflectionStrategyForTesting>(this);
  }
};

REGISTER_CALCULATOR(ReproducibleResamplerCalculatorForTesting);

int SimpleRunner::static_count_ = 0;

TEST(PacketResamplerCalculatorTest, NoPacketsInStream) {
@ -561,13 +561,13 @@ cc_test(
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:opencv_imgcodecs",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:parse_text_proto",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],
@ -15,6 +15,7 @@
#include <cmath>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
@ -28,7 +29,6 @@
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
@ -41,7 +41,7 @@ class InferenceCalculatorSelectorImpl
        (options.has_delegate() && options.delegate().has_gpu());
    if (should_use_gpu) {
      impls.emplace_back("Metal");
      impls.emplace_back("MlDrift");
      impls.emplace_back("MlDriftWebGl");
      impls.emplace_back("Gl");
    }
    impls.emplace_back("Cpu");
@ -118,8 +118,8 @@ struct InferenceCalculatorGl : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorGl";
};

struct InferenceCalculatorMlDrift : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorMlDrift";
struct InferenceCalculatorMlDriftWebGl : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorMlDriftWebGl";
};

struct InferenceCalculatorMetal : public InferenceCalculator {
@ -51,12 +51,12 @@ message InferenceCalculatorOptions {

  // This option is valid for TFLite GPU delegate API2 only,
  // Choose any of available APIs to force running inference using it.
  enum API {
  enum Api {
    ANY = 0;
    OPENGL = 1;
    OPENCL = 2;
  }
  optional API api = 4 [default = ANY];
  optional Api api = 4 [default = ANY];

  // This option is valid for TFLite GPU delegate API2 only,
  // Set to true to use 16-bit float precision. If max precision is needed,
@ -136,7 +136,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
  const auto& model = *model_packet_.Get();
  tflite::ops::builtin::BuiltinOpResolver op_resolver =
      kSideInCustomOpResolver(cc).GetOr(
          tflite::ops::builtin::BuiltinOpResolver());
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());

  tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
  RET_CHECK(interpreter_);
@ -59,7 +59,7 @@ const std::vector<Param>& GetParams() {
    p.back().delegate.mutable_gpu();
#endif  // TARGET_IPHONE_SIMULATOR
#if __EMSCRIPTEN__
    p.push_back({"MlDrift", "MlDrift"});
    p.push_back({"MlDriftWebGl", "MlDriftWebGl"});
    p.back().delegate.mutable_gpu();
#endif  // __EMSCRIPTEN__
#if __ANDROID__ && 0  // Disabled for now since emulator can't go GLESv3
@ -63,7 +63,7 @@ class InferenceCalculatorGlImpl
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
  bool allow_precision_loss_ = false;
  mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API
  mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
      tflite_gpu_runner_api_;
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

@ -244,7 +244,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
  const auto& model = *model_packet_.Get();
  tflite::ops::builtin::BuiltinOpResolver op_resolver =
      kSideInCustomOpResolver(cc).GetOr(
          tflite::ops::builtin::BuiltinOpResolver());
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());

  // Create runner
  tflite::gpu::InferenceOptions options;
@ -294,7 +294,7 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
  const auto& model = *model_packet_.Get();
  tflite::ops::builtin::BuiltinOpResolver op_resolver =
      kSideInCustomOpResolver(cc).GetOr(
          tflite::ops::builtin::BuiltinOpResolver());
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());

  tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
  RET_CHECK(interpreter_);
@ -200,7 +200,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
  const auto& model = *model_packet_.Get();
  tflite::ops::builtin::BuiltinOpResolver op_resolver =
      kSideInCustomOpResolver(cc).GetOr(
          tflite::ops::builtin::BuiltinOpResolver());
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());

  tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
  RET_CHECK(interpreter_);
@ -892,13 +892,13 @@ cc_test(
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:tag_map_helper",
        "//mediapipe/framework/tool:validate_type",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:direct_session",
        "@org_tensorflow//tensorflow/core:framework",
@ -923,13 +923,13 @@ cc_test(
        "//mediapipe/framework:packet",
        "//mediapipe/framework:packet_generator_cc_proto",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:tag_map_helper",
        "//mediapipe/framework/tool:validate_type",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:direct_session",
        "@org_tensorflow//tensorflow/core:framework",
@ -954,11 +954,11 @@ cc_test(
        "//mediapipe/framework:packet",
        "//mediapipe/framework:packet_generator_cc_proto",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:tag_map_helper",
        "//mediapipe/framework/tool:validate_type",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:all_kernels",
        "@org_tensorflow//tensorflow/core:direct_session",
@ -981,11 +981,11 @@ cc_test(
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:tag_map_helper",
        "//mediapipe/framework/tool:validate_type",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:all_kernels",
        "@org_tensorflow//tensorflow/core:direct_session",
@ -1144,8 +1144,8 @@ cc_test(
        ":tensorflow_inference_calculator",
        ":tensorflow_session_from_frozen_graph_generator",
        ":tensorflow_session_from_frozen_graph_generator_cc_proto",
        "@com_google_absl//absl/flags:flag",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
@ -16,12 +16,12 @@
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/flags/flag.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/flags/flag.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/flags/flag.h"
#include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
@ -20,7 +21,6 @@
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/flags/flag.h"
#include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
@ -19,7 +20,6 @@
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h"
@ -56,11 +56,4 @@ message UnpackMediaSequenceCalculatorOptions {
  // the clip start and end times and outputs these for the
  // AudioDecoderCalculator to consume.
  optional AudioDecoderOptions base_audio_decoder_options = 9;

  optional string keypoint_names = 10 [
    default =
        "NOSE,LEFT_EAR,RIGHT_EAR,LEFT_SHOULDER,RIGHT_SHOULDER,LEFT_FORE_PAW,RIGHT_FORE_PAW,LEFT_HIP,RIGHT_HIP,LEFT_HIND_PAW,RIGHT_HIND_PAW,ROOT_TAIL"
  ];
  // When the keypoint doesn't exist, output this default value.
  optional float default_keypoint_location = 11 [default = -1.0];
}
@ -147,11 +147,11 @@ cc_test(
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:parse_text_proto",
        "@com_google_absl//absl/flags:flag",
    ],
)

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/flags/flag.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -278,7 +278,7 @@ class TfLiteInferenceCalculator : public CalculatorBase {

  bool use_advanced_gpu_api_ = false;
  bool allow_precision_loss_ = false;
  mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::API
  mediapipe::TfLiteInferenceCalculatorOptions::Delegate::Gpu::Api
      tflite_gpu_runner_api_;

  bool use_kernel_caching_ = false;
@ -702,11 +702,16 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
#if MEDIAPIPE_TFLITE_GL_INFERENCE
  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
  const auto& model = *model_packet_.Get<TfLiteModelPtr>();
  tflite::ops::builtin::BuiltinOpResolver op_resolver;

  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates
      default_op_resolver;
  auto op_resolver_ptr =
      static_cast<const tflite::ops::builtin::BuiltinOpResolver*>(
          &default_op_resolver);
  if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
    op_resolver = cc->InputSidePackets()
                      .Tag("CUSTOM_OP_RESOLVER")
                      .Get<tflite::ops::builtin::BuiltinOpResolver>();
    op_resolver_ptr = &(cc->InputSidePackets()
                            .Tag("CUSTOM_OP_RESOLVER")
                            .Get<tflite::ops::builtin::BuiltinOpResolver>());
  }

  // Create runner
@ -733,7 +738,7 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
    }
  }
  MP_RETURN_IF_ERROR(
      tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
      tflite_gpu_runner_->InitializeWithModel(model, *op_resolver_ptr));

  // Allocate interpreter memory for cpu output.
  if (!gpu_output_) {
@ -786,18 +791,24 @@ absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) {

  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
  const auto& model = *model_packet_.Get<TfLiteModelPtr>();
  tflite::ops::builtin::BuiltinOpResolver op_resolver;

  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates
      default_op_resolver;
  auto op_resolver_ptr =
      static_cast<const tflite::ops::builtin::BuiltinOpResolver*>(
          &default_op_resolver);

  if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
    op_resolver = cc->InputSidePackets()
                      .Tag("CUSTOM_OP_RESOLVER")
                      .Get<tflite::ops::builtin::BuiltinOpResolver>();
    op_resolver_ptr = &(cc->InputSidePackets()
                            .Tag("CUSTOM_OP_RESOLVER")
                            .Get<tflite::ops::builtin::BuiltinOpResolver>());
  }

#if defined(MEDIAPIPE_EDGE_TPU)
  interpreter_ =
      BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
      BuildEdgeTpuInterpreter(model, op_resolver_ptr, edgetpu_context_.get());
#else
  tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
  tflite::InterpreterBuilder(model, *op_resolver_ptr)(&interpreter_);
#endif  // MEDIAPIPE_EDGE_TPU

  RET_CHECK(interpreter_);
@ -51,12 +51,12 @@ message TfLiteInferenceCalculatorOptions {

  // This option is valid for TFLite GPU delegate API2 only,
  // Choose any of available APIs to force running inference using it.
  enum API {
  enum Api {
    ANY = 0;
    OPENGL = 1;
    OPENCL = 2;
  }
  optional API api = 4 [default = ANY];
  optional Api api = 4 [default = ANY];

  // This option is valid for TFLite GPU delegate API2 only,
  // Set to true to use 16-bit float precision. If max precision is needed,
@ -841,12 +841,39 @@ cc_library(
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/util/filtering:one_euro_filter",
        "//mediapipe/util/filtering:relative_velocity_filter",
        "@com_google_absl//absl/algorithm:container",
    ],
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "visibility_smoothing_calculator_proto",
    srcs = ["visibility_smoothing_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "visibility_smoothing_calculator",
    srcs = ["visibility_smoothing_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":visibility_smoothing_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/util/filtering:low_pass_filter",
        "@com_google_absl//absl/algorithm:container",
    ],
    alwayslink = 1,
)

cc_library(
    name = "landmarks_to_floats_calculator",
    srcs = ["landmarks_to_floats_calculator.cc"],
@ -858,7 +885,7 @@ cc_library(
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@eigen_archive//:eigen",
        "@eigen_archive//:eigen3",
    ],
    alwayslink = 1,
)
@ -1194,3 +1221,34 @@ cc_library(
    }),
    alwayslink = 1,
)

cc_library(
    name = "detection_classifications_merger_calculator",
    srcs = ["detection_classifications_merger_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "detection_classifications_merger_calculator_test",
    srcs = ["detection_classifications_merger_calculator_test.cc"],
    deps = [
        ":detection_classifications_merger_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/deps:message_matchers",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
    ],
)
@ -0,0 +1,149 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
// Replaces the classification labels and scores from the input `Detection` with
|
||||
// the ones provided into the input `ClassificationList`. Namely:
|
||||
// * `label_id[i]` becomes `classification[i].index`
|
||||
// * `score[i]` becomes `classification[i].score`
|
||||
// * `label[i]` becomes `classification[i].label` (if present)
|
||||
//
|
||||
// In case the input `ClassificationList` contains no results (i.e.
|
||||
// `classification` is empty, which may happen if the classifier uses a score
|
||||
// threshold and no confident enough result were returned), the input
|
||||
// `Detection` is returned unchanged.
|
||||
//
|
||||
// This is specifically designed for two-stage detection cascades where the
|
||||
// detections returned by a standalone detector (typically a class-agnostic
|
||||
// localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through
|
||||
// the optional "RECT" or "NORM_RECT" input, e.g:
|
||||
//
|
||||
// node {
|
||||
// calculator: "DetectionsToRectsCalculator"
|
||||
// # Output of an upstream object detector.
|
||||
// input_stream: "DETECTION:detection"
|
||||
// output_stream: "NORM_RECT:norm_rect"
|
||||
// }
|
||||
// node {
|
||||
// calculator: "TfLiteTaskImageClassifierCalculator"
|
||||
// input_stream: "IMAGE:image"
|
||||
// input_stream: "NORM_RECT:norm_rect"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result"
|
||||
// }
|
||||
// node {
|
||||
// calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
|
||||
// input_stream: "CLASSIFICATION_RESULT:classification_result"
|
||||
// output_stream: "CLASSIFICATION_LIST:classification_list"
|
||||
// }
|
||||
// node {
|
||||
// calculator: "DetectionClassificationsMergerCalculator"
|
||||
// input_stream: "INPUT_DETECTION:detection"
|
||||
// input_stream: "CLASSIFICATION_LIST:classification_list"
|
||||
// # Final output.
|
||||
// output_stream: "OUTPUT_DETECTION:classified_detection"
|
||||
// }
|
||||
//
|
||||
// Inputs:
|
||||
// INPUT_DETECTION: `Detection` proto.
|
||||
// CLASSIFICATION_LIST: `ClassificationList` proto.
|
||||
//
|
||||
// Output:
|
||||
// OUTPUT_DETECTION: modified `Detection` proto.
|
||||
class DetectionClassificationsMergerCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<Detection> kInputDetection{"INPUT_DETECTION"};
|
||||
static constexpr Input<ClassificationList> kClassificationList{
|
||||
"CLASSIFICATION_LIST"};
|
||||
static constexpr Output<Detection> kOutputDetection{"OUTPUT_DETECTION"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList,
|
||||
kOutputDetection);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator);
|
||||
|
||||
absl::Status DetectionClassificationsMergerCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
RET_CHECK(!kInputDetection(cc).IsEmpty());
|
||||
RET_CHECK(!kClassificationList(cc).IsEmpty());
|
||||
|
||||
Detection detection = *kInputDetection(cc);
|
||||
const ClassificationList& classification_list = *kClassificationList(cc);
|
||||
|
||||
// Update input detection only if classification did return results.
|
||||
if (classification_list.classification_size() != 0) {
|
||||
detection.clear_label_id();
|
||||
detection.clear_score();
|
||||
detection.clear_label();
|
||||
detection.clear_display_name();
|
||||
for (const auto& classification : classification_list.classification()) {
|
||||
if (!classification.has_index()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Missing required 'index' field in Classification proto.");
|
||||
}
|
||||
detection.add_label_id(classification.index());
|
||||
if (!classification.has_score()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Missing required 'score' field in Classification proto.");
|
||||
}
|
||||
detection.add_score(classification.score());
|
||||
if (classification.has_label()) {
|
||||
detection.add_label(classification.label());
|
||||
}
|
||||
if (classification.has_display_name()) {
|
||||
detection.add_display_name(classification.display_name());
|
||||
}
|
||||
}
|
||||
// Post-conversion sanity checks.
|
||||
if (detection.label_size() != 0 &&
|
||||
detection.label_size() != detection.label_id_size()) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"Each input Classification is expected to either always or never "
|
||||
"provide a 'label' field. Found $0 'label' fields for $1 "
|
||||
"'Classification' objects.",
|
||||
/*$0=*/detection.label_size(), /*$1=*/detection.label_id_size()));
|
||||
}
|
||||
if (detection.display_name_size() != 0 &&
|
||||
detection.display_name_size() != detection.label_id_size()) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"Each input Classification is expected to either always or never "
|
||||
"provide a 'display_name' field. Found $0 'display_name' fields for "
|
||||
"$1 'Classification' objects.",
|
||||
/*$0=*/detection.display_name_size(),
|
||||
/*$1=*/detection.label_id_size()));
|
||||
}
|
||||
}
|
||||
kOutputDetection(cc).Send(detection);
|
||||
return absl::OkStatus();
|
||||
}
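The merge rule above boils down to a small transformation. A minimal standalone sketch, using simplified stand-in structs rather than the real `Detection`/`Classification` protos (field names are illustrative only):

// Standalone sketch of the merge performed by
// DetectionClassificationsMergerCalculator: replace label ids, scores and
// labels, or pass the detection through when the classification list is
// empty (e.g. everything fell below the classifier's score threshold).
#include <iostream>
#include <string>
#include <vector>

struct Classification {
  int index;
  float score;
  std::string label;  // May be empty.
};

struct Detection {
  std::vector<int> label_id;
  std::vector<float> score;
  std::vector<std::string> label;
};

Detection Merge(Detection detection,
                const std::vector<Classification>& classifications) {
  if (classifications.empty()) return detection;  // Returned unchanged.
  detection.label_id.clear();
  detection.score.clear();
  detection.label.clear();
  for (const auto& c : classifications) {
    detection.label_id.push_back(c.index);
    detection.score.push_back(c.score);
    if (!c.label.empty()) detection.label.push_back(c.label);
  }
  return detection;
}

int main() {
  Detection detection{{1}, {0.9f}, {"entity"}};
  Detection merged = Merge(detection, {{11, 0.5f, "dog"}, {12, 0.4f, "fox"}});
  for (size_t i = 0; i < merged.label_id.size(); ++i) {
    std::cout << merged.label_id[i] << " " << merged.score[i] << " "
              << merged.label[i] << "\n";  // "11 0.5 dog" then "12 0.4 fox"
  }
}

This mirrors the tests that follow: labels and display names are only copied when present, and the calculator enforces that they are provided either for all classifications or for none.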
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,320 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/message_matchers.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kGraphConfig[] = R"(
|
||||
input_stream: "input_detection"
|
||||
input_stream: "classification_list"
|
||||
output_stream: "output_detection"
|
||||
node {
|
||||
calculator: "DetectionClassificationsMergerCalculator"
|
||||
input_stream: "INPUT_DETECTION:input_detection"
|
||||
input_stream: "CLASSIFICATION_LIST:classification_list"
|
||||
output_stream: "OUTPUT_DETECTION:output_detection"
|
||||
}
|
||||
)";
|
||||
|
||||
constexpr char kInputDetection[] = R"(
|
||||
label: "entity"
|
||||
label_id: 1
|
||||
score: 0.9
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
|
||||
}
|
||||
display_name: "Entity"
|
||||
)";
|
||||
|
||||
// Checks that the input Detection is returned unchanged if the input
|
||||
// ClassificationList does not contain any result.
|
||||
TEST(DetectionClassificationsMergerCalculator, SucceedsWithNoClassification) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>("");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Get and validate output.
|
||||
EXPECT_THAT(output_packets, testing::SizeIs(1));
|
||||
const Detection& output_detection = output_packets[0].Get<Detection>();
|
||||
EXPECT_THAT(output_detection, mediapipe::EqualsProto(input_detection));
|
||||
}
|
||||
|
||||
// Checks that merging succeeds when the input ClassificationList includes
|
||||
// labels and display names.
|
||||
TEST(DetectionClassificationsMergerCalculator,
|
||||
SucceedsWithLabelsAndDisplayNames) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" }
|
||||
classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Get and validate output.
|
||||
EXPECT_THAT(output_packets, testing::SizeIs(1));
|
||||
const Detection& output_detection = output_packets[0].Get<Detection>();
|
||||
EXPECT_THAT(output_detection,
|
||||
mediapipe::EqualsProto(ParseTextProtoOrDie<Detection>(R"(
|
||||
label: "dog"
|
||||
label: "fox"
|
||||
label_id: 11
|
||||
label_id: 12
|
||||
score: 0.5
|
||||
score: 0.4
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
|
||||
}
|
||||
display_name: "Dog"
|
||||
display_name: "Fox"
|
||||
)")));
|
||||
}
|
||||
|
||||
// Checks that merging succeeds when the input ClassificationList doesn't
|
||||
// include labels and display names.
|
||||
TEST(DetectionClassificationsMergerCalculator,
|
||||
SucceedsWithoutLabelsAndDisplayNames) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { index: 11 score: 0.5 }
|
||||
classification { index: 12 score: 0.4 }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Get and validate output.
|
||||
EXPECT_THAT(output_packets, testing::SizeIs(1));
|
||||
const Detection& output_detection = output_packets[0].Get<Detection>();
|
||||
EXPECT_THAT(output_detection,
|
||||
mediapipe::EqualsProto(ParseTextProtoOrDie<Detection>(R"(
|
||||
label_id: 11
|
||||
label_id: 12
|
||||
score: 0.5
|
||||
score: 0.4
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 50 ymin: 60 width: 70 height: 80 }
|
||||
}
|
||||
)")));
|
||||
}
|
||||
|
||||
// Checks that merging fails if the input ClassificationList misses mandatory
|
||||
// "index" field.
|
||||
TEST(DetectionClassificationsMergerCalculator, FailsWithMissingIndex) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { score: 0.5 label: "dog" }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
|
||||
}
|
||||
|
||||
// Checks that merging fails if the input ClassificationList misses mandatory
|
||||
// "score" field.
|
||||
TEST(DetectionClassificationsMergerCalculator, FailsWithMissingScore) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { index: 11 label: "dog" }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
|
||||
}
|
||||
|
||||
// Checks that merging fails if the input ClassificationList has an
|
||||
// inconsistent number of labels.
|
||||
TEST(DetectionClassificationsMergerCalculator,
|
||||
FailsWithInconsistentNumberOfLabels) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { index: 11 score: 0.5 label: "dog" display_name: "Dog" }
|
||||
classification { index: 12 score: 0.4 display_name: "Fox" }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
|
||||
}
|
||||
|
||||
// Checks that merging fails if the input ClassificationList has an
|
||||
// inconsistent number of display names.
|
||||
TEST(DetectionClassificationsMergerCalculator,
|
||||
FailsWithInconsistentNumberOfDisplayNames) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(kGraphConfig);
|
||||
|
||||
// Prepare input packets.
|
||||
const Detection& input_detection =
|
||||
ParseTextProtoOrDie<Detection>(kInputDetection);
|
||||
Packet input_detection_packet =
|
||||
MakePacket<Detection>(input_detection).At(Timestamp(0));
|
||||
const ClassificationList& classification_list =
|
||||
ParseTextProtoOrDie<ClassificationList>(R"(
|
||||
classification { index: 11 score: 0.5 label: "dog" }
|
||||
classification { index: 12 score: 0.4 label: "fox" display_name: "Fox" }
|
||||
)");
|
||||
Packet classification_list_packet =
|
||||
MakePacket<ClassificationList>(classification_list).At(Timestamp(0));
|
||||
|
||||
// Catch output.
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detection", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(
|
||||
graph.AddPacketToInputStream("input_detection", input_detection_packet));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("classification_list",
|
||||
classification_list_packet));
|
||||
ASSERT_EQ(graph.WaitUntilIdle().code(), absl::StatusCode::kInvalidArgument);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -12,12 +12,15 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/util/filtering/one_euro_filter.h"
|
||||
#include "mediapipe/util/filtering/relative_velocity_filter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -25,19 +28,54 @@ namespace mediapipe {
|
|||
namespace {
|
||||
|
||||
constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
|
||||
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
|
||||
|
||||
using mediapipe::OneEuroFilter;
|
||||
using mediapipe::RelativeVelocityFilter;
|
||||
|
||||
void NormalizedLandmarksToLandmarks(
|
||||
const NormalizedLandmarkList& norm_landmarks, const int image_width,
|
||||
const int image_height, LandmarkList* landmarks) {
|
||||
for (int i = 0; i < norm_landmarks.landmark_size(); ++i) {
|
||||
const auto& norm_landmark = norm_landmarks.landmark(i);
|
||||
|
||||
auto* landmark = landmarks->add_landmark();
|
||||
landmark->set_x(norm_landmark.x() * image_width);
|
||||
landmark->set_y(norm_landmark.y() * image_height);
|
||||
// Scale Z the same way as X (using image width).
|
||||
landmark->set_z(norm_landmark.z() * image_width);
|
||||
landmark->set_visibility(norm_landmark.visibility());
|
||||
landmark->set_presence(norm_landmark.presence());
|
||||
}
|
||||
}
|
||||
|
||||
void LandmarksToNormalizedLandmarks(const LandmarkList& landmarks,
|
||||
const int image_width,
|
||||
const int image_height,
|
||||
NormalizedLandmarkList* norm_landmarks) {
|
||||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
||||
const auto& landmark = landmarks.landmark(i);
|
||||
|
||||
auto* norm_landmark = norm_landmarks->add_landmark();
|
||||
norm_landmark->set_x(landmark.x() / image_width);
|
||||
norm_landmark->set_y(landmark.y() / image_height);
|
||||
// Scale Z the same way as X (using image width).
|
||||
norm_landmark->set_z(landmark.z() / image_width);
|
||||
norm_landmark->set_visibility(landmark.visibility());
|
||||
norm_landmark->set_presence(landmark.presence());
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate object scale to use its inverse value as velocity scale for
|
||||
// RelativeVelocityFilter. If the value is too small (less than
|
||||
// `options_.min_allowed_object_scale`), smoothing will be disabled and
|
||||
// landmarks will be returned as is.
|
||||
// Object scale is calculated as the average of the bounding box width and height,
|
||||
// with sides parallel to the axes.
|
||||
float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width,
|
||||
int image_height) {
|
||||
float GetObjectScale(const LandmarkList& landmarks) {
|
||||
const auto& lm_minmax_x = absl::c_minmax_element(
|
||||
landmarks.landmark(),
|
||||
[](const auto& a, const auto& b) { return a.x() < b.x(); });
|
||||
|
@ -50,8 +88,8 @@ float GetObjectScale(const NormalizedLandmarkList& landmarks, int image_width,
|
|||
const float y_min = lm_minmax_y.first->y();
|
||||
const float y_max = lm_minmax_y.second->y();
|
||||
|
||||
const float object_width = (x_max - x_min) * image_width;
|
||||
const float object_height = (y_max - y_min) * image_height;
|
||||
const float object_width = x_max - x_min;
|
||||
const float object_height = y_max - y_min;
|
||||
|
||||
return (object_width + object_height) / 2.0f;
|
||||
}
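Taken together, the helpers above move normalized landmarks into pixel space (z scaled by image width, like x), and GetObjectScale then measures the axis-aligned bounding box of the result. A rough, self-contained sketch with plain structs standing in for the landmark protos:

// Illustration of the normalized-to-pixel conversion and the object scale
// (average of the bounding box width and height). Plain structs stand in
// for the MediaPipe landmark protos.
#include <algorithm>
#include <iostream>
#include <vector>

struct Landmark { float x, y, z; };

std::vector<Landmark> ToPixelSpace(const std::vector<Landmark>& norm,
                                   int image_width, int image_height) {
  std::vector<Landmark> out;
  for (const auto& lm : norm) {
    // Z is scaled the same way as X (by image width).
    out.push_back({lm.x * image_width, lm.y * image_height,
                   lm.z * image_width});
  }
  return out;
}

float GetObjectScale(const std::vector<Landmark>& lms) {
  const auto x_cmp = [](const Landmark& a, const Landmark& b) { return a.x < b.x; };
  const auto y_cmp = [](const Landmark& a, const Landmark& b) { return a.y < b.y; };
  const auto [min_x, max_x] = std::minmax_element(lms.begin(), lms.end(), x_cmp);
  const auto [min_y, max_y] = std::minmax_element(lms.begin(), lms.end(), y_cmp);
  return ((max_x->x - min_x->x) + (max_y->y - min_y->y)) / 2.0f;
}

int main() {
  const std::vector<Landmark> norm = {{0.2f, 0.3f, 0.0f}, {0.6f, 0.7f, 0.1f}};
  const auto pixels = ToPixelSpace(norm, /*image_width=*/640, /*image_height=*/480);
  // Width = 0.4 * 640 = 256, height = 0.4 * 480 = 192, scale = 224.
  std::cout << GetObjectScale(pixels) << "\n";
}

Because the filters below operate on LandmarkList values directly, the scale is expressed in the same units as the landmarks themselves (pixels after the conversion above).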
|
||||
|
@ -63,19 +101,17 @@ class LandmarksFilter {
|
|||
|
||||
virtual absl::Status Reset() { return absl::OkStatus(); }
|
||||
|
||||
virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const std::pair<int, int>& image_size,
|
||||
virtual absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) = 0;
|
||||
LandmarkList* out_landmarks) = 0;
|
||||
};
|
||||
|
||||
// Returns landmarks as is without smoothing.
|
||||
class NoFilter : public LandmarksFilter {
|
||||
public:
|
||||
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const std::pair<int, int>& image_size,
|
||||
absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) override {
|
||||
LandmarkList* out_landmarks) override {
|
||||
*out_landmarks = in_landmarks;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -85,10 +121,11 @@ class NoFilter : public LandmarksFilter {
|
|||
class VelocityFilter : public LandmarksFilter {
|
||||
public:
|
||||
VelocityFilter(int window_size, float velocity_scale,
|
||||
float min_allowed_object_scale)
|
||||
float min_allowed_object_scale, bool disable_value_scaling)
|
||||
: window_size_(window_size),
|
||||
velocity_scale_(velocity_scale),
|
||||
min_allowed_object_scale_(min_allowed_object_scale) {}
|
||||
min_allowed_object_scale_(min_allowed_object_scale),
|
||||
disable_value_scaling_(disable_value_scaling) {}
|
||||
|
||||
absl::Status Reset() override {
|
||||
x_filters_.clear();
|
||||
|
@ -97,45 +134,37 @@ class VelocityFilter : public LandmarksFilter {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const std::pair<int, int>& image_size,
|
||||
absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) override {
|
||||
// Get image size.
|
||||
int image_width;
|
||||
int image_height;
|
||||
std::tie(image_width, image_height) = image_size;
|
||||
|
||||
LandmarkList* out_landmarks) override {
|
||||
// Get value scale as the inverse of the object scale.
|
||||
// If the value is too small, smoothing will be disabled and landmarks will be
|
||||
// returned as is.
|
||||
const float object_scale =
|
||||
GetObjectScale(in_landmarks, image_width, image_height);
|
||||
if (object_scale < min_allowed_object_scale_) {
|
||||
*out_landmarks = in_landmarks;
|
||||
return absl::OkStatus();
|
||||
float value_scale = 1.0f;
|
||||
if (!disable_value_scaling_) {
|
||||
const float object_scale = GetObjectScale(in_landmarks);
|
||||
if (object_scale < min_allowed_object_scale_) {
|
||||
*out_landmarks = in_landmarks;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
value_scale = 1.0f / object_scale;
|
||||
}
|
||||
const float value_scale = 1.0f / object_scale;
|
||||
|
||||
// Initialize filters once.
|
||||
MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size()));
|
||||
|
||||
// Filter landmarks. Every axis of every landmark is filtered separately.
|
||||
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
|
||||
const NormalizedLandmark& in_landmark = in_landmarks.landmark(i);
|
||||
const auto& in_landmark = in_landmarks.landmark(i);
|
||||
|
||||
NormalizedLandmark* out_landmark = out_landmarks->add_landmark();
|
||||
auto* out_landmark = out_landmarks->add_landmark();
|
||||
*out_landmark = in_landmark;
|
||||
out_landmark->set_x(x_filters_[i].Apply(timestamp, value_scale,
|
||||
in_landmark.x() * image_width) /
|
||||
image_width);
|
||||
out_landmark->set_y(y_filters_[i].Apply(timestamp, value_scale,
|
||||
in_landmark.y() * image_height) /
|
||||
image_height);
|
||||
// Scale Z the save was as X (using image width).
|
||||
out_landmark->set_z(z_filters_[i].Apply(timestamp, value_scale,
|
||||
in_landmark.z() * image_width) /
|
||||
image_width);
|
||||
out_landmark->set_x(
|
||||
x_filters_[i].Apply(timestamp, value_scale, in_landmark.x()));
|
||||
out_landmark->set_y(
|
||||
y_filters_[i].Apply(timestamp, value_scale, in_landmark.y()));
|
||||
out_landmark->set_z(
|
||||
z_filters_[i].Apply(timestamp, value_scale, in_landmark.z()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -165,12 +194,83 @@ class VelocityFilter : public LandmarksFilter {
|
|||
int window_size_;
|
||||
float velocity_scale_;
|
||||
float min_allowed_object_scale_;
|
||||
bool disable_value_scaling_;
|
||||
|
||||
std::vector<RelativeVelocityFilter> x_filters_;
|
||||
std::vector<RelativeVelocityFilter> y_filters_;
|
||||
std::vector<RelativeVelocityFilter> z_filters_;
|
||||
};
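VelocityFilter above delegates the per-axis smoothing to RelativeVelocityFilter from mediapipe/util/filtering, which is not part of this change. Purely to illustrate the underlying idea (fast motion is smoothed less, slow motion more), here is a velocity-adaptive exponential smoother with a made-up weighting formula; treat it as an analogy, not MediaPipe's implementation:

// Illustrative velocity-adaptive smoothing (not the actual
// RelativeVelocityFilter): the faster the value changes, the more weight
// the new measurement receives.
#include <cmath>
#include <iostream>

class VelocityAdaptiveSmoother {
 public:
  explicit VelocityAdaptiveSmoother(float velocity_scale)
      : velocity_scale_(velocity_scale) {}

  float Apply(float value, float dt_seconds) {
    if (!initialized_) {
      initialized_ = true;
      prev_ = value;
      return value;
    }
    const float velocity = std::abs(value - prev_) / dt_seconds;
    // Made-up mapping from speed to a smoothing weight in [0, 1).
    const float alpha = 1.0f - std::exp(-velocity * velocity_scale_);
    prev_ = alpha * value + (1.0f - alpha) * prev_;
    return prev_;
  }

 private:
  float velocity_scale_;
  float prev_ = 0.0f;
  bool initialized_ = false;
};

int main() {
  VelocityAdaptiveSmoother smoother(/*velocity_scale=*/0.5f);
  for (float v : {0.0f, 0.1f, 0.1f, 5.0f}) {
    std::cout << smoother.Apply(v, /*dt_seconds=*/0.033f) << "\n";
  }
}

The value_scale passed to the real filter plays a similar role: dividing by the object scale makes the effective smoothing roughly invariant to how large the tracked object appears.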
|
||||
|
||||
// Please check OneEuroFilter documentation for details.
|
||||
class OneEuroFilterImpl : public LandmarksFilter {
|
||||
public:
|
||||
OneEuroFilterImpl(double frequency, double min_cutoff, double beta,
|
||||
double derivate_cutoff)
|
||||
: frequency_(frequency),
|
||||
min_cutoff_(min_cutoff),
|
||||
beta_(beta),
|
||||
derivate_cutoff_(derivate_cutoff) {}
|
||||
|
||||
absl::Status Reset() override {
|
||||
x_filters_.clear();
|
||||
y_filters_.clear();
|
||||
z_filters_.clear();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
LandmarkList* out_landmarks) override {
|
||||
// Initialize filters once.
|
||||
MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size()));
|
||||
|
||||
// Filter landmarks. Every axis of every landmark is filtered separately.
|
||||
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
|
||||
const auto& in_landmark = in_landmarks.landmark(i);
|
||||
|
||||
auto* out_landmark = out_landmarks->add_landmark();
|
||||
*out_landmark = in_landmark;
|
||||
out_landmark->set_x(x_filters_[i].Apply(timestamp, in_landmark.x()));
|
||||
out_landmark->set_y(y_filters_[i].Apply(timestamp, in_landmark.y()));
|
||||
out_landmark->set_z(z_filters_[i].Apply(timestamp, in_landmark.z()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
// Initializes filters for the first time or after Reset. If initialized then
|
||||
// check the size.
|
||||
absl::Status InitializeFiltersIfEmpty(const int n_landmarks) {
|
||||
if (!x_filters_.empty()) {
|
||||
RET_CHECK_EQ(x_filters_.size(), n_landmarks);
|
||||
RET_CHECK_EQ(y_filters_.size(), n_landmarks);
|
||||
RET_CHECK_EQ(z_filters_.size(), n_landmarks);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_landmarks; ++i) {
|
||||
x_filters_.push_back(
|
||||
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
|
||||
y_filters_.push_back(
|
||||
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
|
||||
z_filters_.push_back(
|
||||
OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
double frequency_;
|
||||
double min_cutoff_;
|
||||
double beta_;
|
||||
double derivate_cutoff_;
|
||||
|
||||
std::vector<OneEuroFilter> x_filters_;
|
||||
std::vector<OneEuroFilter> y_filters_;
|
||||
std::vector<OneEuroFilter> z_filters_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// A calculator to smooth landmarks over time.
|
||||
|
@ -207,16 +307,21 @@ class LandmarksSmoothingCalculator : public CalculatorBase {
|
|||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
LandmarksFilter* landmarks_filter_;
|
||||
std::unique_ptr<LandmarksFilter> landmarks_filter_;
|
||||
};
|
||||
REGISTER_CALCULATOR(LandmarksSmoothingCalculator);
|
||||
|
||||
absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>();
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Set<NormalizedLandmarkList>();
|
||||
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>();
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Set<NormalizedLandmarkList>();
|
||||
} else {
|
||||
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
|
||||
cc->Outputs().Tag(kFilteredLandmarksTag).Set<LandmarkList>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -227,12 +332,19 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) {
|
|||
// Pick landmarks filter.
|
||||
const auto& options = cc->Options<LandmarksSmoothingCalculatorOptions>();
|
||||
if (options.has_no_filter()) {
|
||||
landmarks_filter_ = new NoFilter();
|
||||
landmarks_filter_ = absl::make_unique<NoFilter>();
|
||||
} else if (options.has_velocity_filter()) {
|
||||
landmarks_filter_ = new VelocityFilter(
|
||||
landmarks_filter_ = absl::make_unique<VelocityFilter>(
|
||||
options.velocity_filter().window_size(),
|
||||
options.velocity_filter().velocity_scale(),
|
||||
options.velocity_filter().min_allowed_object_scale());
|
||||
options.velocity_filter().min_allowed_object_scale(),
|
||||
options.velocity_filter().disable_value_scaling());
|
||||
} else if (options.has_one_euro_filter()) {
|
||||
landmarks_filter_ = absl::make_unique<OneEuroFilterImpl>(
|
||||
options.one_euro_filter().frequency(),
|
||||
options.one_euro_filter().min_cutoff(),
|
||||
options.one_euro_filter().beta(),
|
||||
options.one_euro_filter().derivate_cutoff());
|
||||
} else {
|
||||
RET_CHECK_FAIL()
|
||||
<< "Landmarks filter is either not specified or not supported";
|
||||
|
@ -244,25 +356,53 @@ absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) {
|
|||
absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) {
|
||||
// Check that landmarks are not empty and reset the filter if so.
|
||||
// Don't emit an empty packet for this timestamp.
|
||||
if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) {
|
||||
if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) &&
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) ||
|
||||
(cc->Inputs().HasTag(kLandmarksTag) &&
|
||||
cc->Inputs().Tag(kLandmarksTag).IsEmpty())) {
|
||||
MP_RETURN_IF_ERROR(landmarks_filter_->Reset());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& in_landmarks =
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
|
||||
const auto& image_size =
|
||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||
const auto& timestamp =
|
||||
absl::Microseconds(cc->InputTimestamp().Microseconds());
|
||||
|
||||
auto out_landmarks = absl::make_unique<NormalizedLandmarkList>();
|
||||
MP_RETURN_IF_ERROR(landmarks_filter_->Apply(in_landmarks, image_size,
|
||||
timestamp, out_landmarks.get()));
|
||||
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
|
||||
const auto& in_norm_landmarks =
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Add(out_landmarks.release(), cc->InputTimestamp());
|
||||
int image_width;
|
||||
int image_height;
|
||||
std::tie(image_width, image_height) =
|
||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||
|
||||
auto in_landmarks = absl::make_unique<LandmarkList>();
|
||||
NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, image_height,
|
||||
in_landmarks.get());
|
||||
|
||||
auto out_landmarks = absl::make_unique<LandmarkList>();
|
||||
MP_RETURN_IF_ERROR(landmarks_filter_->Apply(*in_landmarks, timestamp,
|
||||
out_landmarks.get()));
|
||||
|
||||
auto out_norm_landmarks = absl::make_unique<NormalizedLandmarkList>();
|
||||
LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height,
|
||||
out_norm_landmarks.get());
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Add(out_norm_landmarks.release(), cc->InputTimestamp());
|
||||
} else {
|
||||
const auto& in_landmarks =
|
||||
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
|
||||
|
||||
auto out_landmarks = absl::make_unique<LandmarkList>();
|
||||
MP_RETURN_IF_ERROR(
|
||||
landmarks_filter_->Apply(in_landmarks, timestamp, out_landmarks.get()));
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kFilteredLandmarksTag)
|
||||
.Add(out_landmarks.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -39,10 +39,40 @@ message LandmarksSmoothingCalculatorOptions {
|
|||
// If calculated object scale is less than given value smoothing will be
|
||||
// disabled and landmarks will be returned as is.
|
||||
optional float min_allowed_object_scale = 3 [default = 1e-6];
|
||||
|
||||
// Disable value scaling based on object size and use `1.0` instead.
|
||||
// Value scale is calculated as the inverse of the object size. Object size is
|
||||
// calculated as the maximum side of the rectangular bounding box of the object
|
||||
// in the XY plane.
|
||||
optional bool disable_value_scaling = 4 [default = false];
|
||||
}
|
||||
|
||||
// For the details of the filter implementation and the procedure of its
|
||||
// configuration please check http://cristal.univ-lille.fr/~casiez/1euro/
|
||||
message OneEuroFilter {
|
||||
// Frequency of incoming frames defined in seconds. Used only if it can't be
|
||||
// calculated from provided events (e.g. on the very first frame).
|
||||
optional float frequency = 1 [default = 0.033];
|
||||
|
||||
// Minimum cutoff frequency. Start by tuning this parameter while keeping
|
||||
// `beta = 0` to reduce jittering to the desired level. 1Hz (the default
|
||||
// value) is a good starting point.
|
||||
optional float min_cutoff = 2 [default = 1.0];
|
||||
|
||||
// Cutoff slope. After `min_cutoff` is configured, start increasing `beta`
|
||||
// value to reduce the lag introduced by the `min_cutoff`. Find the desired
|
||||
// balance between jittering and lag.
|
||||
optional float beta = 3 [default = 0.0];
|
||||
|
||||
// Cutoff frequency for derivate. It is set to 1Hz in the original
|
||||
// algorithm, but can be tuned to further smooth the speed (i.e. derivate)
|
||||
// on the object.
|
||||
optional float derivate_cutoff = 4 [default = 1.0];
|
||||
}
|
||||
|
||||
oneof filter_options {
|
||||
NoFilter no_filter = 1;
|
||||
VelocityFilter velocity_filter = 2;
|
||||
OneEuroFilter one_euro_filter = 3;
|
||||
}
|
||||
}
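For reference, the equations these OneEuroFilter options control, in a minimal sketch based on the published 1€ filter algorithm linked above (this is not MediaPipe's OneEuroFilter implementation; frequency is treated here as frames per second, and the values are illustrative):

// Sketch of the 1€ (one-euro) filter: smooth the derivative, derive an
// adaptive cutoff from it (min_cutoff + beta * |speed|), then low-pass the
// value with that cutoff. Higher speed -> higher cutoff -> less lag.
#include <cmath>
#include <iostream>

class OneEuroSketch {
 public:
  OneEuroSketch(double frequency, double min_cutoff, double beta,
                double derivate_cutoff)
      : frequency_(frequency),
        min_cutoff_(min_cutoff),
        beta_(beta),
        derivate_cutoff_(derivate_cutoff) {}

  double Apply(double x) {
    if (!initialized_) {
      initialized_ = true;
      prev_x_ = x;
      return x;
    }
    const double dx = (x - prev_x_) * frequency_;
    prev_dx_ = Lerp(prev_dx_, dx, Alpha(derivate_cutoff_));
    const double cutoff = min_cutoff_ + beta_ * std::abs(prev_dx_);
    prev_x_ = Lerp(prev_x_, x, Alpha(cutoff));
    return prev_x_;
  }

 private:
  double Alpha(double cutoff) const {
    constexpr double kPi = 3.14159265358979323846;
    const double tau = 1.0 / (2.0 * kPi * cutoff);
    return 1.0 / (1.0 + tau * frequency_);
  }
  static double Lerp(double prev, double value, double alpha) {
    return alpha * value + (1.0 - alpha) * prev;
  }

  double frequency_, min_cutoff_, beta_, derivate_cutoff_;
  bool initialized_ = false;
  double prev_x_ = 0.0, prev_dx_ = 0.0;
};

int main() {
  OneEuroSketch filter(/*frequency=*/30.0, /*min_cutoff=*/1.0, /*beta=*/0.0,
                       /*derivate_cutoff=*/1.0);
  for (double x : {1.00, 1.02, 0.98, 1.50}) {
    std::cout << filter.Apply(x) << "\n";
  }
}

Tuning follows the comments above: raise min_cutoff (with beta = 0) until jitter is acceptable, then raise beta until the lag during fast motion is acceptable.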
|
||||
|
|
|
@ -34,6 +34,34 @@ constexpr char kRenderScaleTag[] = "RENDER_SCALE";
|
|||
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||
constexpr char kLandmarkLabel[] = "KEYPOINT";
|
||||
|
||||
inline Color DefaultMinDepthLineColor() {
|
||||
Color color;
|
||||
color.set_r(0);
|
||||
color.set_g(0);
|
||||
color.set_b(0);
|
||||
return color;
|
||||
}
|
||||
|
||||
inline Color DefaultMaxDepthLineColor() {
|
||||
Color color;
|
||||
color.set_r(255);
|
||||
color.set_g(255);
|
||||
color.set_b(255);
|
||||
return color;
|
||||
}
|
||||
|
||||
inline Color MixColors(const Color& color1, const Color& color2,
|
||||
float color1_weight) {
|
||||
Color color;
|
||||
color.set_r(static_cast<int>(color1.r() * color1_weight +
|
||||
color2.r() * (1.f - color1_weight)));
|
||||
color.set_g(static_cast<int>(color1.g() * color1_weight +
|
||||
color2.g() * (1.f - color1_weight)));
|
||||
color.set_b(static_cast<int>(color1.b() * color1_weight +
|
||||
color2.b() * (1.f - color1_weight)));
|
||||
return color;
|
||||
}
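These two defaults and MixColors implement the depth gradient for connection lines: each endpoint's z is remapped into [0, 1] between z_min and z_max, and the two depth colors are blended accordingly. A small standalone sketch, with a plain struct standing in for the mediapipe::Color proto and a simple linear Remap (illustration only):

// Depth-to-color blending: mid-depth points get a color halfway between the
// min-depth and max-depth line colors.
#include <algorithm>
#include <iostream>

struct Color { int r, g, b; };

float Remap(float z, float z_min, float z_max) {
  if (z_max <= z_min) return 0.0f;
  return std::clamp((z - z_min) / (z_max - z_min), 0.0f, 1.0f);
}

Color MixColors(const Color& color1, const Color& color2, float color1_weight) {
  const auto mix = [&](int a, int b) {
    return static_cast<int>(a * color1_weight + b * (1.0f - color1_weight));
  };
  return {mix(color1.r, color2.r), mix(color1.g, color2.g),
          mix(color1.b, color2.b)};
}

int main() {
  const Color kMinDepthColor{0, 0, 0};        // Default: black.
  const Color kMaxDepthColor{255, 255, 255};  // Default: white.
  // A landmark halfway between z_min and z_max gets a mid-gray line end.
  const Color c =
      MixColors(kMinDepthColor, kMaxDepthColor, Remap(0.0f, -1.0f, 1.0f));
  std::cout << c.r << " " << c.g << " " << c.b << "\n";  // 127 127 127
}

With the default black/white colors this effectively reproduces the previous grayscale behavior, while the new min_depth_line_color/max_depth_line_color options allow any two-color gradient.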
|
||||
|
||||
inline void SetColor(RenderAnnotation* annotation, const Color& color) {
|
||||
annotation->mutable_color()->set_r(color.r());
|
||||
annotation->mutable_color()->set_g(color.g());
|
||||
|
@ -57,6 +85,23 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min,
|
|||
}
|
||||
}
|
||||
|
||||
template <class LandmarkType>
|
||||
bool IsLandmarkVisibileAndPresent(const LandmarkType& landmark,
|
||||
bool utilize_visibility,
|
||||
float visibility_threshold,
|
||||
bool utilize_presence,
|
||||
float presence_threshold) {
|
||||
if (utilize_visibility && landmark.has_visibility() &&
|
||||
landmark.visibility() < visibility_threshold) {
|
||||
return false;
|
||||
}
|
||||
if (utilize_presence && landmark.has_presence() &&
|
||||
landmark.presence() < presence_threshold) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SetColorSizeValueFromZ(float z, float z_min, float z_max,
|
||||
RenderAnnotation* render_annotation,
|
||||
float min_depth_circle_thickness,
|
||||
|
@ -75,8 +120,9 @@ void SetColorSizeValueFromZ(float z, float z_min, float z_max,
|
|||
|
||||
template <class LandmarkType>
|
||||
void AddConnectionToRenderData(const LandmarkType& start,
|
||||
const LandmarkType& end, int gray_val1,
|
||||
int gray_val2, float thickness, bool normalized,
|
||||
const LandmarkType& end,
|
||||
const Color& color_start, const Color& color_end,
|
||||
float thickness, bool normalized,
|
||||
RenderData* render_data) {
|
||||
auto* connection_annotation = render_data->add_render_annotations();
|
||||
RenderAnnotation::GradientLine* line =
|
||||
|
@ -86,12 +132,13 @@ void AddConnectionToRenderData(const LandmarkType& start,
|
|||
line->set_x_end(end.x());
|
||||
line->set_y_end(end.y());
|
||||
line->set_normalized(normalized);
|
||||
line->mutable_color1()->set_r(gray_val1);
|
||||
line->mutable_color1()->set_g(gray_val1);
|
||||
line->mutable_color1()->set_b(gray_val1);
|
||||
line->mutable_color2()->set_r(gray_val2);
|
||||
line->mutable_color2()->set_g(gray_val2);
|
||||
line->mutable_color2()->set_b(gray_val2);
|
||||
line->mutable_color1()->set_r(color_start.r());
|
||||
line->mutable_color1()->set_g(color_start.g());
|
||||
line->mutable_color1()->set_b(color_start.b());
|
||||
line->mutable_color2()->set_r(color_end.r());
|
||||
line->mutable_color2()->set_g(color_end.g());
|
||||
line->mutable_color2()->set_b(color_end.b());
|
||||
|
||||
connection_annotation->set_thickness(thickness);
|
||||
}
|
||||
|
||||
|
@ -102,26 +149,26 @@ void AddConnectionsWithDepth(const LandmarkListType& landmarks,
|
|||
float visibility_threshold, bool utilize_presence,
|
||||
float presence_threshold, float thickness,
|
||||
bool normalized, float min_z, float max_z,
|
||||
const Color& min_depth_line_color,
|
||||
const Color& max_depth_line_color,
|
||||
RenderData* render_data) {
|
||||
for (int i = 0; i < landmark_connections.size(); i += 2) {
|
||||
const auto& ld0 = landmarks.landmark(landmark_connections[i]);
|
||||
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
|
||||
if (utilize_visibility &&
|
||||
((ld0.has_visibility() && ld0.visibility() < visibility_threshold) ||
|
||||
(ld1.has_visibility() && ld1.visibility() < visibility_threshold))) {
|
||||
if (!IsLandmarkVisibileAndPresent<LandmarkType>(
|
||||
ld0, utilize_visibility, visibility_threshold, utilize_presence,
|
||||
presence_threshold) ||
|
||||
!IsLandmarkVisibileAndPresent<LandmarkType>(
|
||||
ld1, utilize_visibility, visibility_threshold, utilize_presence,
|
||||
presence_threshold)) {
|
||||
continue;
|
||||
}
|
||||
if (utilize_presence &&
|
||||
((ld0.has_presence() && ld0.presence() < presence_threshold) ||
|
||||
(ld1.has_presence() && ld1.presence() < presence_threshold))) {
|
||||
continue;
|
||||
}
|
||||
const int gray_val1 =
|
||||
255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
|
||||
const int gray_val2 =
|
||||
255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
|
||||
AddConnectionToRenderData<LandmarkType>(ld0, ld1, gray_val1, gray_val2,
|
||||
thickness, normalized, render_data);
|
||||
const Color color0 = MixColors(min_depth_line_color, max_depth_line_color,
|
||||
Remap(ld0.z(), min_z, max_z, 1.f));
|
||||
const Color color1 = MixColors(min_depth_line_color, max_depth_line_color,
|
||||
Remap(ld1.z(), min_z, max_z, 1.f));
|
||||
AddConnectionToRenderData<LandmarkType>(ld0, ld1, color0, color1, thickness,
|
||||
normalized, render_data);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,14 +198,12 @@ void AddConnections(const LandmarkListType& landmarks,
|
|||
for (int i = 0; i < landmark_connections.size(); i += 2) {
|
||||
const auto& ld0 = landmarks.landmark(landmark_connections[i]);
|
||||
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
|
||||
if (utilize_visibility &&
|
||||
((ld0.has_visibility() && ld0.visibility() < visibility_threshold) ||
|
||||
(ld1.has_visibility() && ld1.visibility() < visibility_threshold))) {
|
||||
continue;
|
||||
}
|
||||
if (utilize_presence &&
|
||||
((ld0.has_presence() && ld0.presence() < presence_threshold) ||
|
||||
(ld1.has_presence() && ld1.presence() < presence_threshold))) {
|
||||
if (!IsLandmarkVisibileAndPresent<LandmarkType>(
|
||||
ld0, utilize_visibility, visibility_threshold, utilize_presence,
|
||||
presence_threshold) ||
|
||||
!IsLandmarkVisibileAndPresent<LandmarkType>(
|
||||
ld1, utilize_visibility, visibility_threshold, utilize_presence,
|
||||
presence_threshold)) {
|
||||
continue;
|
||||
}
|
||||
AddConnectionToRenderData<LandmarkType>(ld0, ld1, connection_color,
|
||||
|
@ -232,6 +277,13 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
float z_min = 0.f;
|
||||
float z_max = 0.f;
|
||||
|
||||
const Color min_depth_line_color = options_.has_min_depth_line_color()
|
||||
? options_.min_depth_line_color()
|
||||
: DefaultMinDepthLineColor();
|
||||
const Color max_depth_line_color = options_.has_max_depth_line_color()
|
||||
? options_.max_depth_line_color()
|
||||
: DefaultMaxDepthLineColor();
|
||||
|
||||
// Apply scale to `thickness` of rendered landmarks and connections to make
|
||||
// them bigger when object (e.g. pose, hand or face) is closer/bigger and
|
||||
// smaller when object is further/smaller.
|
||||
|
@ -254,7 +306,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
landmarks, landmark_connections_, options_.utilize_visibility(),
|
||||
options_.visibility_threshold(), options_.utilize_presence(),
|
||||
options_.presence_threshold(), thickness, /*normalized=*/false, z_min,
|
||||
z_max, render_data.get());
|
||||
z_max, min_depth_line_color, max_depth_line_color, render_data.get());
|
||||
} else {
|
||||
AddConnections<LandmarkList, Landmark>(
|
||||
landmarks, landmark_connections_, options_.utilize_visibility(),
|
||||
|
@ -265,13 +317,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
||||
const Landmark& landmark = landmarks.landmark(i);
|
||||
|
||||
if (options_.utilize_visibility() && landmark.has_visibility() &&
|
||||
landmark.visibility() < options_.visibility_threshold()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (options_.utilize_presence() && landmark.has_presence() &&
|
||||
landmark.presence() < options_.presence_threshold()) {
|
||||
if (!IsLandmarkVisibileAndPresent<Landmark>(
|
||||
landmark, options_.utilize_visibility(),
|
||||
options_.visibility_threshold(), options_.utilize_presence(),
|
||||
options_.presence_threshold())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -303,7 +352,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
landmarks, landmark_connections_, options_.utilize_visibility(),
|
||||
options_.visibility_threshold(), options_.utilize_presence(),
|
||||
options_.presence_threshold(), thickness, /*normalized=*/true, z_min,
|
||||
z_max, render_data.get());
|
||||
z_max, min_depth_line_color, max_depth_line_color, render_data.get());
|
||||
} else {
|
||||
AddConnections<NormalizedLandmarkList, NormalizedLandmark>(
|
||||
landmarks, landmark_connections_, options_.utilize_visibility(),
|
||||
|
@ -314,12 +363,10 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
||||
const NormalizedLandmark& landmark = landmarks.landmark(i);
|
||||
|
||||
if (options_.utilize_visibility() && landmark.has_visibility() &&
|
||||
landmark.visibility() < options_.visibility_threshold()) {
|
||||
continue;
|
||||
}
|
||||
if (options_.utilize_presence() && landmark.has_presence() &&
|
||||
landmark.presence() < options_.presence_threshold()) {
|
||||
if (!IsLandmarkVisibileAndPresent<NormalizedLandmark>(
|
||||
landmark, options_.utilize_visibility(),
|
||||
options_.visibility_threshold(), options_.utilize_presence(),
|
||||
options_.presence_threshold())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -64,4 +64,10 @@ message LandmarksToRenderDataCalculatorOptions {
|
|||
|
||||
// Max thickness of the drawing for landmark circle.
|
||||
optional double max_depth_circle_thickness = 11 [default = 18.0];
|
||||
|
||||
// Gradient color for the lines connecting landmarks at the minimum depth.
|
||||
optional Color min_depth_line_color = 12;
|
||||
|
||||
// Gradient color for the lines connecting landmarks at the maximum depth.
|
||||
optional Color max_depth_line_color = 13;
|
||||
}
|
||||
|
|
194
mediapipe/calculators/util/visibility_copy_calculator.cc
Normal file
194
mediapipe/calculators/util/visibility_copy_calculator.cc
Normal file
|
@ -0,0 +1,194 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kLandmarksFromTag[] = "LANDMARKS_FROM";
|
||||
constexpr char kNormalizedLandmarksFromTag[] = "NORM_LANDMARKS_FROM";
|
||||
constexpr char kLandmarksToTag[] = "LANDMARKS_TO";
|
||||
constexpr char kNormalizedLandmarksToTag[] = "NORM_LANDMARKS_TO";
|
||||
|
||||
} // namespace
|
||||
|
||||
// A calculator to copy visibility and presence between landmarks.
|
||||
//
|
||||
// Landmarks to copy from and to copy to can be of different type (normalized or
|
||||
// non-normalized), but landmarks to copy to and output landmarks should be of
|
||||
// the same type. Exactly one stream to copy landmarks from, to copy to and to
|
||||
// output should be provided.
|
||||
//
|
||||
// Inputs:
|
||||
// LANDMARKS_FROM (optional): A LandmarkList of landmarks to copy from.
|
||||
// NORM_LANDMARKS_FROM (optional): A NormalizedLandmarkList of landmarks to
|
||||
// copy from.
|
||||
// LANDMARKS_TO (optional): A LandmarkList of landmarks to copy to.
|
||||
// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks to copy
|
||||
// to.
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS_TO (optional): A LandmarkList of landmarks from LANDMARKS_TO and
|
||||
// visibility/presence from LANDMARKS_FROM or NORM_LANDMARKS_FROM.
|
||||
// NORM_LANDMARKS_TO (optional): A NormalizedLandmarkList of landmarks from
// NORM_LANDMARKS_TO and visibility/presence from LANDMARKS_FROM or
// NORM_LANDMARKS_FROM.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "VisibilityCopyCalculator"
|
||||
// input_stream: "NORM_LANDMARKS_FROM:pose_landmarks"
|
||||
// input_stream: "LANDMARKS_TO:pose_world_landmarks"
|
||||
// output_stream: "LANDMARKS_TO:pose_world_landmarks_with_visibility"
|
||||
// options: {
|
||||
// [mediapipe.VisibilityCopyCalculatorOptions.ext] {
|
||||
// copy_visibility: true
|
||||
// copy_presence: true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
class VisibilityCopyCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
template <class LandmarkFromType, class LandmarkToType>
|
||||
absl::Status CopyVisibility(CalculatorContext* cc,
|
||||
const std::string& landmarks_from_tag,
|
||||
const std::string& landmarks_to_tag);
|
||||
|
||||
bool copy_visibility_;
|
||||
bool copy_presence_;
|
||||
};
|
||||
REGISTER_CALCULATOR(VisibilityCopyCalculator);
|
||||
|
||||
absl::Status VisibilityCopyCalculator::GetContract(CalculatorContract* cc) {
|
||||
// Landmarks to copy from.
|
||||
RET_CHECK(cc->Inputs().HasTag(kLandmarksFromTag) ^
|
||||
cc->Inputs().HasTag(kNormalizedLandmarksFromTag))
|
||||
<< "Exatly one landmarks stream to copy from should be provided";
|
||||
if (cc->Inputs().HasTag(kLandmarksFromTag)) {
|
||||
cc->Inputs().Tag(kLandmarksFromTag).Set<LandmarkList>();
|
||||
} else {
|
||||
cc->Inputs().Tag(kNormalizedLandmarksFromTag).Set<NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
// Landmarks to copy to and corresponding output landmarks.
|
||||
RET_CHECK(cc->Inputs().HasTag(kLandmarksToTag) ^
|
||||
cc->Inputs().HasTag(kNormalizedLandmarksToTag))
|
||||
<< "Exatly one landmarks stream to copy to should be provided";
|
||||
if (cc->Inputs().HasTag(kLandmarksToTag)) {
|
||||
cc->Inputs().Tag(kLandmarksToTag).Set<LandmarkList>();
|
||||
|
||||
RET_CHECK(cc->Outputs().HasTag(kLandmarksToTag))
|
||||
<< "Landmarks to copy to and output stream types should be the same";
|
||||
cc->Outputs().Tag(kLandmarksToTag).Set<LandmarkList>();
|
||||
} else {
|
||||
cc->Inputs().Tag(kNormalizedLandmarksToTag).Set<NormalizedLandmarkList>();
|
||||
|
||||
RET_CHECK(cc->Outputs().HasTag(kNormalizedLandmarksToTag))
|
||||
<< "Landmarks to copy to and output stream types should be the same";
|
||||
cc->Outputs().Tag(kNormalizedLandmarksToTag).Set<NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status VisibilityCopyCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const auto& options = cc->Options<VisibilityCopyCalculatorOptions>();
|
||||
copy_visibility_ = options.copy_visibility();
|
||||
copy_presence_ = options.copy_presence();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status VisibilityCopyCalculator::Process(CalculatorContext* cc) {
|
||||
// Switch between all four possible combinations of landmarks from and
|
||||
// landmarks to types (normalized and non-normalized).
|
||||
auto status = absl::OkStatus();
|
||||
if (cc->Inputs().HasTag(kLandmarksFromTag)) {
|
||||
if (cc->Inputs().HasTag(kLandmarksToTag)) {
|
||||
status = CopyVisibility<LandmarkList, LandmarkList>(cc, kLandmarksFromTag,
|
||||
kLandmarksToTag);
|
||||
} else {
|
||||
status = CopyVisibility<LandmarkList, NormalizedLandmarkList>(
|
||||
cc, kLandmarksFromTag, kNormalizedLandmarksToTag);
|
||||
}
|
||||
} else {
|
||||
if (cc->Inputs().HasTag(kLandmarksToTag)) {
|
||||
status = CopyVisibility<NormalizedLandmarkList, LandmarkList>(
|
||||
cc, kNormalizedLandmarksFromTag, kLandmarksToTag);
|
||||
} else {
|
||||
status = CopyVisibility<NormalizedLandmarkList, NormalizedLandmarkList>(
|
||||
cc, kNormalizedLandmarksFromTag, kNormalizedLandmarksToTag);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <class LandmarkFromType, class LandmarkToType>
|
||||
absl::Status VisibilityCopyCalculator::CopyVisibility(
|
||||
CalculatorContext* cc, const std::string& landmarks_from_tag,
|
||||
const std::string& landmarks_to_tag) {
|
||||
// Check that both landmarks to copy from and to copy to are non-empty.
|
||||
if (cc->Inputs().Tag(landmarks_from_tag).IsEmpty() ||
|
||||
cc->Inputs().Tag(landmarks_to_tag).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto landmarks_from =
|
||||
cc->Inputs().Tag(landmarks_from_tag).Get<LandmarkFromType>();
|
||||
const auto landmarks_to =
|
||||
cc->Inputs().Tag(landmarks_to_tag).Get<LandmarkToType>();
|
||||
auto landmarks_out = absl::make_unique<LandmarkToType>();
|
||||
|
||||
for (int i = 0; i < landmarks_from.landmark_size(); ++i) {
|
||||
const auto& landmark_from = landmarks_from.landmark(i);
|
||||
const auto& landmark_to = landmarks_to.landmark(i);
|
||||
|
||||
// Create output landmark and copy all fields from the `to` landmark.
|
||||
const auto& landmark_out = landmarks_out->add_landmark();
|
||||
*landmark_out = landmark_to;
|
||||
|
||||
// Copy visibility and presence from the `from` landmark.
|
||||
if (copy_visibility_) {
|
||||
landmark_out->set_visibility(landmark_from.visibility());
|
||||
}
|
||||
if (copy_presence_) {
|
||||
landmark_out->set_presence(landmark_from.presence());
|
||||
}
|
||||
}
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(landmarks_to_tag)
|
||||
.Add(landmarks_out.release(), cc->InputTimestamp());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
29
mediapipe/calculators/util/visibility_copy_calculator.proto
Normal file
29
mediapipe/calculators/util/visibility_copy_calculator.proto
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
|
||||
message VisibilityCopyCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional VisibilityCopyCalculatorOptions ext = 363728421;
|
||||
}
|
||||
|
||||
optional bool copy_visibility = 1 [default = true];
|
||||
|
||||
optional bool copy_presence = 2 [default = true];
|
||||
}
|
243
mediapipe/calculators/util/visibility_smoothing_calculator.cc
Normal file
243
mediapipe/calculators/util/visibility_smoothing_calculator.cc
Normal file
|
@ -0,0 +1,243 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/util/filtering/low_pass_filter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
|
||||
constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS";
|
||||
|
||||
using mediapipe::LowPassFilter;
|
||||
|
||||
// Abstract class for various visibility filters.
|
||||
class VisibilityFilter {
|
||||
public:
|
||||
virtual ~VisibilityFilter() = default;
|
||||
|
||||
virtual absl::Status Reset() { return absl::OkStatus(); }
|
||||
|
||||
virtual absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
LandmarkList* out_landmarks) = 0;
|
||||
|
||||
virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) = 0;
|
||||
};
|
||||
|
||||
// Returns visibility as is without smoothing.
|
||||
class NoFilter : public VisibilityFilter {
|
||||
public:
|
||||
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) override {
|
||||
*out_landmarks = in_landmarks;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
LandmarkList* out_landmarks) override {
|
||||
*out_landmarks = in_landmarks;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
// Please check LowPassFilter documentation for details.
|
||||
class LowPassVisibilityFilter : public VisibilityFilter {
|
||||
public:
|
||||
LowPassVisibilityFilter(float alpha) : alpha_(alpha) {}
|
||||
|
||||
absl::Status Reset() override {
|
||||
visibility_filters_.clear();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Apply(const LandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
LandmarkList* out_landmarks) override {
|
||||
return ApplyImpl<LandmarkList>(in_landmarks, timestamp, out_landmarks);
|
||||
}
|
||||
|
||||
absl::Status Apply(const NormalizedLandmarkList& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
NormalizedLandmarkList* out_landmarks) override {
|
||||
return ApplyImpl<NormalizedLandmarkList>(in_landmarks, timestamp,
|
||||
out_landmarks);
|
||||
}
|
||||
|
||||
private:
|
||||
template <class LandmarksType>
|
||||
absl::Status ApplyImpl(const LandmarksType& in_landmarks,
|
||||
const absl::Duration& timestamp,
|
||||
LandmarksType* out_landmarks) {
|
||||
// Initializes filters for the first time or after Reset. If already
// initialized, checks that the filter count matches the number of landmarks.
|
||||
int n_landmarks = in_landmarks.landmark_size();
|
||||
if (!visibility_filters_.empty()) {
|
||||
RET_CHECK_EQ(visibility_filters_.size(), n_landmarks);
|
||||
} else {
|
||||
visibility_filters_.resize(n_landmarks, LowPassFilter(alpha_));
|
||||
}
|
||||
|
||||
// Filter visibilities.
|
||||
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
|
||||
const auto& in_landmark = in_landmarks.landmark(i);
|
||||
|
||||
auto* out_landmark = out_landmarks->add_landmark();
|
||||
*out_landmark = in_landmark;
|
||||
out_landmark->set_visibility(
|
||||
visibility_filters_[i].Apply(in_landmark.visibility()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
std::vector<LowPassFilter> visibility_filters_;
|
||||
};
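// For reference, the per-landmark smoothing above reduces to an exponential
// moving average; a minimal sketch of the same update rule (not the actual
// mediapipe::LowPassFilter implementation) is:
//
//   float Smooth(float new_value, float stored_value, float alpha) {
//     // alpha in [0, 1]: smaller alpha -> smoother output but larger lag.
//     return alpha * new_value + (1.0f - alpha) * stored_value;
//   }
//
// With alpha = 0.1 and a stored value of 0.0, a new visibility of 1.0 only
// moves the filtered value to 0.1 on that frame.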
|
||||
|
||||
} // namespace
|
||||
|
||||
// A calculator to smooth landmark visibilities over time.
|
||||
//
|
||||
// Exactly one landmarks input stream is expected. Output stream type should be
|
||||
// the same as the input one.
|
||||
//
|
||||
// Inputs:
|
||||
// LANDMARKS (optional): A LandmarkList of landmarks you want to smooth.
|
||||
// NORM_LANDMARKS (optional): A NormalizedLandmarkList of landmarks you want
|
||||
// to smooth.
|
||||
//
|
||||
// Outputs:
|
||||
// FILTERED_LANDMARKS (optional): A LandmarkList of smoothed landmarks.
|
||||
// NORM_FILTERED_LANDMARKS (optional): A NormalizedLandmarkList of smoothed
|
||||
// landmarks.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "VisibilitySmoothingCalculator"
|
||||
// input_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||
// output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered"
|
||||
// options: {
|
||||
// [mediapipe.VisibilitySmoothingCalculatorOptions.ext] {
|
||||
// low_pass_filter: {
|
||||
// alpha: 0.1
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
class VisibilitySmoothingCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<VisibilityFilter> visibility_filter_;
|
||||
};
|
||||
REGISTER_CALCULATOR(VisibilitySmoothingCalculator);
|
||||
|
||||
absl::Status VisibilitySmoothingCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag) ^
|
||||
cc->Inputs().HasTag(kLandmarksTag))
|
||||
<< "Exactly one landmarks input stream is expected";
|
||||
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Set<NormalizedLandmarkList>();
|
||||
RET_CHECK(cc->Outputs().HasTag(kNormalizedFilteredLandmarksTag))
|
||||
<< "Landmarks output stream should of the same type as input one";
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Set<NormalizedLandmarkList>();
|
||||
} else {
|
||||
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
|
||||
RET_CHECK(cc->Outputs().HasTag(kFilteredLandmarksTag))
|
||||
<< "Landmarks output stream should of the same type as input one";
|
||||
cc->Outputs().Tag(kFilteredLandmarksTag).Set<LandmarkList>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status VisibilitySmoothingCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
// Pick visibility filter.
|
||||
const auto& options = cc->Options<VisibilitySmoothingCalculatorOptions>();
|
||||
if (options.has_no_filter()) {
|
||||
visibility_filter_ = absl::make_unique<NoFilter>();
|
||||
} else if (options.has_low_pass_filter()) {
|
||||
visibility_filter_ = absl::make_unique<LowPassVisibilityFilter>(
|
||||
options.low_pass_filter().alpha());
|
||||
} else {
|
||||
RET_CHECK_FAIL()
|
||||
<< "Visibility filter is either not specified or not supported";
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status VisibilitySmoothingCalculator::Process(CalculatorContext* cc) {
|
||||
// Check that landmarks are not empty and reset the filter if so.
|
||||
// Don't emit an empty packet for this timestamp.
|
||||
if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) &&
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) ||
|
||||
(cc->Inputs().HasTag(kLandmarksTag) &&
|
||||
cc->Inputs().Tag(kLandmarksTag).IsEmpty())) {
|
||||
MP_RETURN_IF_ERROR(visibility_filter_->Reset());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& timestamp =
|
||||
absl::Microseconds(cc->InputTimestamp().Microseconds());
|
||||
|
||||
if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) {
|
||||
const auto& in_landmarks =
|
||||
cc->Inputs().Tag(kNormalizedLandmarksTag).Get<NormalizedLandmarkList>();
|
||||
auto out_landmarks = absl::make_unique<NormalizedLandmarkList>();
|
||||
MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp,
|
||||
out_landmarks.get()));
|
||||
cc->Outputs()
|
||||
.Tag(kNormalizedFilteredLandmarksTag)
|
||||
.Add(out_landmarks.release(), cc->InputTimestamp());
|
||||
} else {
|
||||
const auto& in_landmarks =
|
||||
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
|
||||
auto out_landmarks = absl::make_unique<LandmarkList>();
|
||||
MP_RETURN_IF_ERROR(visibility_filter_->Apply(in_landmarks, timestamp,
|
||||
out_landmarks.get()));
|
||||
cc->Outputs()
|
||||
.Tag(kFilteredLandmarksTag)
|
||||
.Add(out_landmarks.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
|
||||
message VisibilitySmoothingCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional VisibilitySmoothingCalculatorOptions ext = 360207350;
|
||||
}
|
||||
|
||||
// Default behaviour and fast way to disable smoothing.
|
||||
message NoFilter {}
|
||||
|
||||
message LowPassFilter {
|
||||
// Coefficient applied to a new value, while `1 - alpha` is applied to a
// stored value. Should be in the [0, 1] range. The smaller the value, the
// smoother the result and the bigger the lag.
|
||||
optional float alpha = 1 [default = 0.1];
|
||||
}
|
||||
|
||||
oneof filter_options {
|
||||
NoFilter no_filter = 1;
|
||||
LowPassFilter low_pass_filter = 2;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kRectTag[] = "NORM_RECT";
|
||||
|
||||
} // namespace
|
||||
|
||||
// Projects world landmarks from the rectangle to original coordinates.
|
||||
//
|
||||
// World landmarks are predicted in meters rather than in pixels of the image
|
||||
// and have origin in the middle of the hips rather than in the corner of the
|
||||
// pose image (cropped with given rectangle). Thus only rotation (but not scale
|
||||
// and translation) is applied to the landmarks to transform them back to
|
||||
// original coordinates.
|
||||
//
|
||||
// Input:
|
||||
// LANDMARKS: A LandmarkList representing world landmarks in the rectangle.
|
||||
// NORM_RECT: A NormalizedRect representing a normalized rectangle in image
|
||||
// coordinates.
|
||||
//
|
||||
// Output:
|
||||
// LANDMARKS: A LandmarkList representing world landmarks projected (rotated
|
||||
// but not scaled or translated) from the rectangle to original
|
||||
// coordinates.
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "WorldLandmarkProjectionCalculator"
|
||||
// input_stream: "LANDMARKS:landmarks"
|
||||
// input_stream: "NORM_RECT:rect"
|
||||
// output_stream: "LANDMARKS:projected_landmarks"
|
||||
// }
|
||||
//
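// Note: the projection below is a plain 2D rotation of each landmark's (x, y)
// about the origin by the rectangle's rotation angle, with z copied through
// unchanged:
//
//   x' = cos(angle) * x - sin(angle) * y
//   y' = sin(angle) * x + cos(angle) * y
//
// For example, with angle = pi / 2 a landmark at (1, 0) maps to (0, 1).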
|
||||
class WorldLandmarkProjectionCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag(kLandmarksTag).Set<LandmarkList>();
|
||||
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
|
||||
cc->Outputs().Tag(kLandmarksTag).Set<LandmarkList>();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
// Check that landmarks and rect are not empty.
|
||||
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty() ||
|
||||
cc->Inputs().Tag(kRectTag).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& in_landmarks =
|
||||
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
|
||||
const auto& in_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
|
||||
|
||||
auto out_landmarks = absl::make_unique<LandmarkList>();
|
||||
for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
|
||||
const auto& in_landmark = in_landmarks.landmark(i);
|
||||
|
||||
Landmark* out_landmark = out_landmarks->add_landmark();
|
||||
*out_landmark = in_landmark;
|
||||
|
||||
const float angle = in_rect.rotation();
|
||||
out_landmark->set_x(std::cos(angle) * in_landmark.x() -
|
||||
std::sin(angle) * in_landmark.y());
|
||||
out_landmark->set_y(std::sin(angle) * in_landmark.x() +
|
||||
std::cos(angle) * in_landmark.y());
|
||||
}
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kLandmarksTag)
|
||||
.Add(out_landmarks.release(), cc->InputTimestamp());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(WorldLandmarkProjectionCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -426,6 +426,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -450,6 +451,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:opencv_video",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -534,6 +536,7 @@ cc_test(
|
|||
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
|
||||
"//mediapipe/util/tracking:box_tracker_cc_proto",
|
||||
"//mediapipe/util/tracking:tracking_cc_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -27,13 +27,14 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:opencv_highgui",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:opencv_video",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ COPY . /mediapipe/
|
|||
|
||||
# Install bazel
|
||||
# Please match the current MediaPipe Bazel requirements according to docs.
|
||||
ARG BAZEL_VERSION=3.4.1
|
||||
ARG BAZEL_VERSION=3.7.2
|
||||
RUN mkdir /bazel && \
|
||||
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
||||
wget --no-check-certificate -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
|
||||
|
|
|
@ -15,10 +15,11 @@
|
|||
// An example of sending OpenCV webcam frames into a MediaPipe graph.
|
||||
#include <cstdlib>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/opencv_highgui_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
|
@ -30,15 +31,14 @@ constexpr char kInputStream[] = "input_video";
|
|||
constexpr char kOutputStream[] = "output_video";
|
||||
constexpr char kWindowName[] = "MediaPipe";
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
DEFINE_string(input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
DEFINE_string(output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
ABSL_FLAG(std::string, output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
|
||||
absl::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
|
@ -143,7 +143,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
|
|
|
@ -23,13 +23,14 @@ cc_library(
|
|||
srcs = ["simple_run_graph_main.cc"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:map_util",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -41,13 +42,14 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:opencv_highgui",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:opencv_video",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -62,7 +64,6 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:opencv_highgui",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
|
@ -72,5 +73,7 @@ cc_library(
|
|||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -54,18 +54,27 @@ mediapipe_cc_proto_library(
|
|||
deps = [":border_detection_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "content_zooming_calculator_state",
|
||||
hdrs = ["content_zooming_calculator_state.h"],
|
||||
deps = [
|
||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "content_zooming_calculator",
|
||||
srcs = ["content_zooming_calculator.cc"],
|
||||
deps = [
|
||||
":content_zooming_calculator_cc_proto",
|
||||
":content_zooming_calculator_state",
|
||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
|
@ -88,7 +97,9 @@ mediapipe_cc_proto_library(
|
|||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//mediapipe/examples:__subpackages__"],
|
||||
visibility = [
|
||||
"//mediapipe/examples:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":content_zooming_calculator_proto",
|
||||
],
|
||||
|
@ -127,6 +138,7 @@ cc_test(
|
|||
deps = [
|
||||
":content_zooming_calculator",
|
||||
":content_zooming_calculator_cc_proto",
|
||||
":content_zooming_calculator_state",
|
||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||
"//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -368,7 +380,6 @@ cc_test(
|
|||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||
|
@ -376,6 +387,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -17,12 +17,11 @@
|
|||
|
||||
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_builder.h"
|
||||
|
@ -33,12 +32,18 @@ constexpr char kSalientRegions[] = "SALIENT_REGIONS";
|
|||
constexpr char kDetections[] = "DETECTIONS";
|
||||
constexpr char kDetectedBorders[] = "BORDERS";
|
||||
constexpr char kCropRect[] = "CROP_RECT";
|
||||
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
|
||||
// Field-of-view (degrees) of the camera's x-axis (width).
|
||||
// TODO: Parameterize FOV based on camera specs.
|
||||
constexpr float kFieldOfView = 60;
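// For reference, the path solvers below are configured with a pixels-per-degree
// scale of frame dimension / kFieldOfView; e.g. a hypothetical 1920-pixel-wide
// frame gives 1920 / 60 = 32 pixels per degree of assumed horizontal field of
// view.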
|
||||
// A pointer to a ContentZoomingCalculatorStateCacheType in a side packet.
|
||||
// Used to save state on Close and load state on Open in a new graph.
|
||||
// Can be used to preserve state between graphs.
|
||||
constexpr char kStateCache[] = "STATE_CACHE";
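// A caller that wants reframing state to survive across graph instances can
// own the cache and pass a pointer to it as this side packet. A minimal
// sketch, mirroring the pattern used in the calculator tests below (where
// `runner` is a CalculatorRunner):
//
//   mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
//   runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
//       mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
//
// A production graph would attach the same pointer packet via its input side
// packets instead of a CalculatorRunner.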
|
||||
|
||||
namespace mediapipe {
|
||||
namespace autoflip {
|
||||
using StateCacheType = ContentZoomingCalculatorStateCacheType;
|
||||
|
||||
// Content zooming calculator zooms in on content when a detection has
|
||||
// "only_required" set true or any raw detection input. It does this by
|
||||
|
@ -49,8 +54,7 @@ namespace autoflip {
|
|||
// include mobile makeover and autofliplive face reframing.
|
||||
class ContentZoomingCalculator : public CalculatorBase {
|
||||
public:
|
||||
ContentZoomingCalculator()
|
||||
: initialized_(false), last_only_required_detection_(0) {}
|
||||
ContentZoomingCalculator() : initialized_(false) {}
|
||||
~ContentZoomingCalculator() override {}
|
||||
ContentZoomingCalculator(const ContentZoomingCalculator&) = delete;
|
||||
ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete;
|
||||
|
@ -58,8 +62,25 @@ class ContentZoomingCalculator : public CalculatorBase {
|
|||
static absl::Status GetContract(mediapipe::CalculatorContract* cc);
|
||||
absl::Status Open(mediapipe::CalculatorContext* cc) override;
|
||||
absl::Status Process(mediapipe::CalculatorContext* cc) override;
|
||||
absl::Status Close(mediapipe::CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Tries to load state from a state cache, if provided. Falls back to
// initializing state if no cache or no value in the cache is available.
|
||||
absl::Status MaybeLoadState(mediapipe::CalculatorContext* cc, int frame_width,
|
||||
int frame_height);
|
||||
// Saves state to a state-cache, if provided.
|
||||
absl::Status SaveState(mediapipe::CalculatorContext* cc) const;
|
||||
// Initializes the calculator for the given frame size, creating path solvers
|
||||
// and resetting history like last measured values.
|
||||
absl::Status InitializeState(int frame_width, int frame_height);
|
||||
// Adjusts state to work with an updated frame size.
|
||||
absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
|
||||
// Returns true if we are zooming to the initial rect.
|
||||
bool IsZoomingToInitialRect(const Timestamp& timestamp) const;
|
||||
// Builds the output rectangle when zooming to the initial rect.
|
||||
absl::StatusOr<mediapipe::Rect> GetInitialZoomingRect(
|
||||
int frame_width, int frame_height, const Timestamp& timestamp) const;
|
||||
// Converts bounds to tilt offset, pan offset and height.
|
||||
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
||||
float ymax, int* tilt_offset,
|
||||
|
@ -76,6 +97,10 @@ class ContentZoomingCalculator : public CalculatorBase {
|
|||
std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
|
||||
// Are parameters initialized.
|
||||
bool initialized_;
|
||||
// Stores the time of the first crop rectangle.
|
||||
Timestamp first_rect_timestamp_;
|
||||
// Stores the first crop rectangle.
|
||||
mediapipe::NormalizedRect first_rect_;
|
||||
// Stores the time of the last "only_required" input.
|
||||
int64 last_only_required_detection_;
|
||||
// Rect values of last message with detection(s).
|
||||
|
@ -116,6 +141,12 @@ absl::Status ContentZoomingCalculator::GetContract(
|
|||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kFirstCropRect)) {
|
||||
cc->Outputs().Tag(kFirstCropRect).Set<mediapipe::NormalizedRect>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag(kStateCache)) {
|
||||
cc->InputSidePackets().Tag(kStateCache).Set<StateCacheType*>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -135,6 +166,13 @@ absl::Status ContentZoomingCalculator::Open(mediapipe::CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::Close(mediapipe::CalculatorContext* cc) {
|
||||
if (initialized_) {
|
||||
MP_RETURN_IF_ERROR(SaveState(cc));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
|
||||
float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
|
||||
int* pan_offset, int* height) {
|
||||
|
@ -275,39 +313,89 @@ absl::Status ContentZoomingCalculator::UpdateAspectAndMax() {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::Process(
|
||||
mediapipe::CalculatorContext* cc) {
|
||||
// For async subgraph support, return on empty video size packets.
|
||||
if (cc->Inputs().HasTag(kVideoSize) &&
|
||||
cc->Inputs().Tag(kVideoSize).IsEmpty()) {
|
||||
absl::Status ContentZoomingCalculator::MaybeLoadState(
|
||||
mediapipe::CalculatorContext* cc, int frame_width, int frame_height) {
|
||||
const auto* state_cache =
|
||||
cc->InputSidePackets().HasTag(kStateCache)
|
||||
? cc->InputSidePackets().Tag(kStateCache).Get<StateCacheType*>()
|
||||
: nullptr;
|
||||
if (!state_cache || !state_cache->has_value()) {
|
||||
return InitializeState(frame_width, frame_height);
|
||||
}
|
||||
|
||||
const ContentZoomingCalculatorState& state = state_cache->value();
|
||||
frame_width_ = state.frame_width;
|
||||
frame_height_ = state.frame_height;
|
||||
path_solver_pan_ =
|
||||
std::make_unique<KinematicPathSolver>(state.path_solver_pan);
|
||||
path_solver_tilt_ =
|
||||
std::make_unique<KinematicPathSolver>(state.path_solver_tilt);
|
||||
path_solver_zoom_ =
|
||||
std::make_unique<KinematicPathSolver>(state.path_solver_zoom);
|
||||
first_rect_timestamp_ = state.first_rect_timestamp;
|
||||
first_rect_ = state.first_rect;
|
||||
last_only_required_detection_ = state.last_only_required_detection;
|
||||
last_measured_height_ = state.last_measured_height;
|
||||
last_measured_x_offset_ = state.last_measured_x_offset;
|
||||
last_measured_y_offset_ = state.last_measured_y_offset;
|
||||
MP_RETURN_IF_ERROR(UpdateAspectAndMax());
|
||||
|
||||
return UpdateForResolutionChange(frame_width, frame_height);
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::SaveState(
|
||||
mediapipe::CalculatorContext* cc) const {
|
||||
auto* state_cache =
|
||||
cc->InputSidePackets().HasTag(kStateCache)
|
||||
? cc->InputSidePackets().Tag(kStateCache).Get<StateCacheType*>()
|
||||
: nullptr;
|
||||
if (!state_cache) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
int frame_width, frame_height;
|
||||
MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height));
|
||||
|
||||
// Init on first call.
|
||||
if (!initialized_) {
|
||||
frame_width_ = frame_width;
|
||||
frame_height_ = frame_height;
|
||||
path_solver_pan_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_pan(), 0, frame_width_,
|
||||
static_cast<float>(frame_width_) / kFieldOfView);
|
||||
path_solver_tilt_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_tilt(), 0, frame_height_,
|
||||
static_cast<float>(frame_height_) / kFieldOfView);
|
||||
MP_RETURN_IF_ERROR(UpdateAspectAndMax());
|
||||
int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() /
|
||||
static_cast<double>(kFieldOfView));
|
||||
path_solver_zoom_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_zoom(), min_zoom_size,
|
||||
max_frame_value_ * frame_height_,
|
||||
static_cast<float>(frame_height_) / kFieldOfView);
|
||||
last_measured_height_ = max_frame_value_ * frame_height_;
|
||||
last_measured_x_offset_ = target_aspect_ * frame_width_;
|
||||
last_measured_y_offset_ = frame_width_ / 2;
|
||||
initialized_ = true;
|
||||
}
|
||||
*state_cache = ContentZoomingCalculatorState{
|
||||
.frame_height = frame_height_,
|
||||
.frame_width = frame_width_,
|
||||
.path_solver_zoom = *path_solver_zoom_,
|
||||
.path_solver_pan = *path_solver_pan_,
|
||||
.path_solver_tilt = *path_solver_tilt_,
|
||||
.first_rect_timestamp = first_rect_timestamp_,
|
||||
.first_rect = first_rect_,
|
||||
.last_only_required_detection = last_only_required_detection_,
|
||||
.last_measured_height = last_measured_height_,
|
||||
.last_measured_x_offset = last_measured_x_offset_,
|
||||
.last_measured_y_offset = last_measured_y_offset_,
|
||||
};
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::InitializeState(int frame_width,
|
||||
int frame_height) {
|
||||
frame_width_ = frame_width;
|
||||
frame_height_ = frame_height;
|
||||
path_solver_pan_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_pan(), 0, frame_width_,
|
||||
static_cast<float>(frame_width_) / kFieldOfView);
|
||||
path_solver_tilt_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_tilt(), 0, frame_height_,
|
||||
static_cast<float>(frame_height_) / kFieldOfView);
|
||||
MP_RETURN_IF_ERROR(UpdateAspectAndMax());
|
||||
int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() /
|
||||
static_cast<double>(kFieldOfView));
|
||||
path_solver_zoom_ = std::make_unique<KinematicPathSolver>(
|
||||
options_.kinematic_options_zoom(), min_zoom_size,
|
||||
max_frame_value_ * frame_height_,
|
||||
static_cast<float>(frame_height_) / kFieldOfView);
|
||||
first_rect_timestamp_ = Timestamp::Unset();
|
||||
last_only_required_detection_ = 0;
|
||||
last_measured_height_ = max_frame_value_ * frame_height_;
|
||||
last_measured_x_offset_ = target_aspect_ * frame_width_;
|
||||
last_measured_y_offset_ = frame_width_ / 2;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::UpdateForResolutionChange(
|
||||
int frame_width, int frame_height) {
|
||||
// Update state for change in input resolution.
|
||||
if (frame_width_ != frame_width || frame_height_ != frame_height) {
|
||||
double width_scale = frame_width / static_cast<double>(frame_width_);
|
||||
|
@ -328,6 +416,74 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree(
|
||||
static_cast<float>(frame_height_) / kFieldOfView));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool ContentZoomingCalculator::IsZoomingToInitialRect(
|
||||
const Timestamp& timestamp) const {
|
||||
if (options_.us_to_first_rect() == 0 ||
|
||||
first_rect_timestamp_ == Timestamp::Unset()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
|
||||
return (0 <= delta_us && delta_us <= options_.us_to_first_rect());
|
||||
}
|
||||
|
||||
namespace {
|
||||
double easeInQuad(double t) { return t * t; }
|
||||
double easeOutQuad(double t) { return -1 * t * (t - 2); }
|
||||
double easeInOutQuad(double t) {
|
||||
if (t < 0.5) {
|
||||
return easeInQuad(t * 2) * 0.5;
|
||||
} else {
|
||||
return easeOutQuad(t * 2 - 1) * 0.5 + 0.5;
|
||||
}
|
||||
}
|
||||
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
|
||||
} // namespace
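// For reference, easeInOutQuad maps normalized time t in [0, 1] onto an
// S-curve, so the move toward the first rect starts and ends slowly:
// easeInOutQuad(0.25) = 0.125, easeInOutQuad(0.5) = 0.5,
// easeInOutQuad(0.75) = 0.875. A sketch of how GetInitialZoomingRect below
// uses it (values are illustrative only):
//
//   const double t = 0.5;  // halfway through the post-delay time budget
//   const double width = lerp(1.0, first_rect_.width(), easeInOutQuad(t));
//   // easeInOutQuad(0.5) == 0.5, so width == (1.0 + first_rect_.width()) / 2.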
|
||||
|
||||
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetInitialZoomingRect(
|
||||
int frame_width, int frame_height, const Timestamp& timestamp) const {
|
||||
RET_CHECK(IsZoomingToInitialRect(timestamp))
|
||||
<< "Must only be called if zooming to initial rect.";
|
||||
|
||||
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
|
||||
const int64 delay = options_.us_to_first_rect_delay();
|
||||
const double interpolation = easeInOutQuad(std::max(
|
||||
0.0, (delta_us - delay) /
|
||||
static_cast<double>(options_.us_to_first_rect() - delay)));
|
||||
|
||||
const double x_center = lerp(0.5, first_rect_.x_center(), interpolation);
|
||||
const double y_center = lerp(0.5, first_rect_.y_center(), interpolation);
|
||||
const double width = lerp(1.0, first_rect_.width(), interpolation);
|
||||
const double height = lerp(1.0, first_rect_.height(), interpolation);
|
||||
|
||||
mediapipe::Rect gpu_rect;
|
||||
gpu_rect.set_x_center(x_center * frame_width);
|
||||
gpu_rect.set_width(width * frame_width);
|
||||
gpu_rect.set_y_center(y_center * frame_height);
|
||||
gpu_rect.set_height(height * frame_height);
|
||||
return gpu_rect;
|
||||
}
|
||||
|
||||
absl::Status ContentZoomingCalculator::Process(
|
||||
mediapipe::CalculatorContext* cc) {
|
||||
// For async subgraph support, return on empty video size packets.
|
||||
if (cc->Inputs().HasTag(kVideoSize) &&
|
||||
cc->Inputs().Tag(kVideoSize).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
int frame_width, frame_height;
|
||||
MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height));
|
||||
|
||||
// Init on first call or re-init always if configured to be stateless.
|
||||
if (!initialized_) {
|
||||
MP_RETURN_IF_ERROR(MaybeLoadState(cc, frame_width, frame_height));
|
||||
initialized_ = !options_.is_stateless();
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(UpdateForResolutionChange(frame_width, frame_height));
|
||||
}
|
||||
|
||||
bool only_required_found = false;
|
||||
|
||||
|
@ -348,31 +504,52 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
|
||||
if (cc->Inputs().HasTag(kDetections)) {
|
||||
if (cc->Inputs().Tag(kDetections).IsEmpty()) {
|
||||
auto default_rect = absl::make_unique<mediapipe::Rect>();
|
||||
default_rect->set_x_center(frame_width_ / 2);
|
||||
default_rect->set_y_center(frame_height_ / 2);
|
||||
default_rect->set_width(frame_width_);
|
||||
default_rect->set_height(frame_height_);
|
||||
cc->Outputs().Tag(kCropRect).Add(default_rect.release(),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
auto raw_detections =
|
||||
cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
|
||||
for (const auto& detection : raw_detections) {
|
||||
only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(
|
||||
detection, options_.detection_shift_vertical(),
|
||||
options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax));
|
||||
if (last_only_required_detection_ == 0) {
|
||||
// If no detections are available and we never had any,
|
||||
// simply return the full-image rectangle as crop-rect.
|
||||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
auto default_rect = absl::make_unique<mediapipe::Rect>();
|
||||
default_rect->set_x_center(frame_width_ / 2);
|
||||
default_rect->set_y_center(frame_height_ / 2);
|
||||
default_rect->set_width(frame_width_);
|
||||
default_rect->set_height(frame_height_);
|
||||
cc->Outputs().Tag(kCropRect).Add(default_rect.release(),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
}
|
||||
// Also provide a first crop rect: in this case a zero-sized one.
|
||||
if (cc->Outputs().HasTag(kFirstCropRect)) {
|
||||
cc->Outputs()
|
||||
.Tag(kFirstCropRect)
|
||||
.Add(new mediapipe::NormalizedRect(),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} else {
|
||||
auto raw_detections = cc->Inputs()
|
||||
.Tag(kDetections)
|
||||
.Get<std::vector<mediapipe::Detection>>();
|
||||
for (const auto& detection : raw_detections) {
|
||||
only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(
|
||||
detection, options_.detection_shift_vertical(),
|
||||
options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert bounds to tilt/zoom and in pixel coordinates.
|
||||
int offset_y, height, offset_x;
|
||||
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||
&offset_x, &height));
|
||||
bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
||||
|
||||
if (only_required_found) {
|
||||
int offset_y, height, offset_x;
|
||||
if (zooming_to_initial_rect) {
|
||||
// If we are zooming to the first rect, ignore any new incoming detections.
|
||||
height = last_measured_height_;
|
||||
offset_x = last_measured_x_offset_;
|
||||
offset_y = last_measured_y_offset_;
|
||||
} else if (only_required_found) {
|
||||
// Convert bounds to tilt/zoom and in pixel coordinates.
|
||||
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||
&offset_x, &height));
|
||||
// An only_required detection was found.
|
||||
last_only_required_detection_ = cc->InputTimestamp().Microseconds();
|
||||
last_measured_height_ = height;
|
||||
|
@ -383,7 +560,9 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
options_.us_before_zoomout()) {
|
||||
// No only_required detections were found within salient-region packets
// arriving within the us_before_zoomout duration.
|
||||
height = max_frame_value_ * frame_height_;
|
||||
height = max_frame_value_ * frame_height_ +
|
||||
(options_.kinematic_options_zoom().min_motion_to_reframe() *
|
||||
(static_cast<float>(frame_height_) / kFieldOfView));
|
||||
offset_x = (target_aspect_ * height) / 2;
|
||||
offset_y = frame_height_ / 2;
|
||||
} else {
|
||||
|
@ -463,17 +642,44 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (first_rect_timestamp_ == Timestamp::Unset() &&
|
||||
options_.us_to_first_rect() != 0) {
|
||||
first_rect_timestamp_ = cc->InputTimestamp();
|
||||
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
|
||||
first_rect_.set_width(path_height * target_aspect_ /
|
||||
static_cast<float>(frame_width_));
|
||||
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
|
||||
first_rect_.set_height(path_height / static_cast<float>(frame_height_));
|
||||
// After setting the first rectangle, check whether we should zoom to it.
|
||||
zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
||||
}
|
||||
|
||||
// Transmit downstream to glcroppingcalculator.
|
||||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
auto gpu_rect = absl::make_unique<mediapipe::Rect>();
|
||||
gpu_rect->set_x_center(path_offset_x);
|
||||
gpu_rect->set_width(path_height * target_aspect_);
|
||||
gpu_rect->set_y_center(path_offset_y);
|
||||
gpu_rect->set_height(path_height);
|
||||
std::unique_ptr<mediapipe::Rect> gpu_rect;
|
||||
if (zooming_to_initial_rect) {
|
||||
auto rect = GetInitialZoomingRect(frame_width, frame_height,
|
||||
cc->InputTimestamp());
|
||||
MP_RETURN_IF_ERROR(rect.status());
|
||||
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
|
||||
} else {
|
||||
gpu_rect = absl::make_unique<mediapipe::Rect>();
|
||||
gpu_rect->set_x_center(path_offset_x);
|
||||
gpu_rect->set_width(path_height * target_aspect_);
|
||||
gpu_rect->set_y_center(path_offset_y);
|
||||
gpu_rect->set_height(path_height);
|
||||
}
|
||||
cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag(kFirstCropRect)) {
|
||||
cc->Outputs()
|
||||
.Tag(kFirstCropRect)
|
||||
.Add(new mediapipe::NormalizedRect(first_rect_),
|
||||
Timestamp(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.autoflip;
|
|||
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// NextTag: 14
|
||||
// NextTag: 17
|
||||
message ContentZoomingCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ContentZoomingCalculatorOptions ext = 313091992;
|
||||
|
@ -55,6 +55,16 @@ message ContentZoomingCalculatorOptions {
|
|||
// Defines the smallest value in degrees the camera is permitted to zoom.
|
||||
optional float max_zoom_value_deg = 13 [default = 35];
|
||||
|
||||
// If true, the calculator does not keep state between frames and computes the
// final crop rect for each frame independently.
|
||||
optional bool is_stateless = 14 [default = false];
|
||||
|
||||
// Duration (in microseconds) for moving to the first crop rect.
|
||||
optional int64 us_to_first_rect = 15 [default = 0];
|
||||
// Duration (in microseconds) to delay moving to the first crop rect.
|
||||
// Used only if us_to_first_rect is set and is interpreted as part of the
|
||||
// us_to_first_rect time budget.
|
||||
optional int64 us_to_first_rect_delay = 16 [default = 0];
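// For example, with us_to_first_rect = 1000000 and us_to_first_rect_delay =
// 500000 (the values used in the StartZoomedOut test), the output stays on
// the full frame for the first 500 ms and then eases toward the first crop
// rect over the remaining 500 ms.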
|
||||
|
||||
// Deprecated parameters
|
||||
optional KinematicOptions kinematic_options = 2 [deprecated = true];
|
||||
optional int64 min_motion_to_reframe = 4 [deprecated = true];
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
#ifndef MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_
|
||||
#define MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace autoflip {
|
||||
|
||||
struct ContentZoomingCalculatorState {
|
||||
int frame_height = -1;
|
||||
int frame_width = -1;
|
||||
// Path solver used to smooth top/bottom border crop values.
|
||||
KinematicPathSolver path_solver_zoom;
|
||||
KinematicPathSolver path_solver_pan;
|
||||
KinematicPathSolver path_solver_tilt;
|
||||
// Stores the time of the first crop rectangle.
|
||||
Timestamp first_rect_timestamp;
|
||||
// Stores the first crop rectangle.
|
||||
mediapipe::NormalizedRect first_rect;
|
||||
// Stores the time of the last "only_required" input.
|
||||
int64 last_only_required_detection = 0;
|
||||
// Rect values of last message with detection(s).
|
||||
int last_measured_height = 0;
|
||||
int last_measured_x_offset = 0;
|
||||
int last_measured_y_offset = 0;
|
||||
};
|
||||
|
||||
using ContentZoomingCalculatorStateCacheType =
|
||||
std::optional<ContentZoomingCalculatorState>;
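// The cache starts out empty (std::nullopt); SaveState() fills it in Close()
// and MaybeLoadState() only restores from it when has_value() is true, so a
// freshly constructed cache simply triggers normal initialization.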
|
||||
|
||||
} // namespace autoflip
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_EXAMPLES_DESKTOP_AUTOFLIP_CALCULATORS_CONTENT_ZOOMING_CALCULATOR_STATE_H_
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
|
@ -109,6 +110,7 @@ const char kConfigD[] = R"(
|
|||
input_stream: "VIDEO_SIZE:size"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
output_stream: "CROP_RECT:rect"
|
||||
output_stream: "FIRST_CROP_RECT:first_rect"
|
||||
options: {
|
||||
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
||||
max_zoom_value_deg: 0
|
||||
|
@ -147,19 +149,24 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
|||
const int width, const int height,
|
||||
CalculatorRunner* runner) {
|
||||
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
||||
mediapipe::Detection detection;
|
||||
detection.mutable_location_data()->set_format(
|
||||
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_height(position.height);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
|
||||
position.width);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
|
||||
position.x);
|
||||
detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
|
||||
position.y);
|
||||
detections->push_back(detection);
|
||||
if (position.width > 0 && position.height > 0) {
|
||||
mediapipe::Detection detection;
|
||||
detection.mutable_location_data()->set_format(
|
||||
mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_height(position.height);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_width(position.width);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_xmin(position.x);
|
||||
detection.mutable_location_data()
|
||||
->mutable_relative_bounding_box()
|
||||
->set_ymin(position.y);
|
||||
detections->push_back(detection);
|
||||
}
|
||||
runner->MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
|
||||
|
@ -185,7 +192,6 @@ void CheckCropRect(const int x_center, const int y_center, const int width,
|
|||
EXPECT_EQ(rect.width(), width);
|
||||
EXPECT_EQ(rect.height(), height);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ZoomTest) {
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigA));
|
||||
|
@ -244,6 +250,46 @@ TEST(ContentZoomingCalculatorTest, PanConfig) {
|
|||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, PanConfigWithCache) {
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
config.add_input_side_packet("STATE_CACHE:state_cache");
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
|
||||
options->mutable_kinematic_options_pan()->set_update_rate_seconds(2);
|
||||
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(50.0);
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(483, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
// Now repeat the last frame for a new runner without the cache to see a reset
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr);
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 2000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost.
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, TiltConfig) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
|
@ -280,6 +326,46 @@ TEST(ContentZoomingCalculatorTest, ZoomConfig) {
|
|||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ZoomConfigWithCache) {
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
config.add_input_side_packet("STATE_CACHE:state_cache");
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(50.0);
|
||||
options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(50.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
|
||||
options->mutable_kinematic_options_zoom()->set_update_rate_seconds(2);
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 111, 111, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(450, 550, 139, 139, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
// Now repeat the last frame for a new runner without the cache to see a reset
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(nullptr);
|
||||
AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 2000000, runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(525, 625, 166, 166, 0, // Without a cache, state was lost.
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigB));
|
||||
|
@ -509,6 +595,32 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeStationary) {
|
|||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ResolutionChangeStationaryWithCache) {
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
config.add_input_side_packet("STATE_CACHE:state_cache");
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 222, 222, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1, 500, 500,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500 * 0.5, 500 * 0.5, 222 * 0.5, 222 * 0.5, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
|
@ -527,6 +639,37 @@ TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) {
|
|||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ResolutionChangeZoomingWithCache) {
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
config.add_input_side_packet("STATE_CACHE:state_cache");
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.1, .1, .8, .8), 0, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 888, 888, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
// The second runner should just resume based on state from the first runner.
|
||||
{
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
runner->MutableSidePackets()->Tag("STATE_CACHE") = MakePacket<
|
||||
mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(&cache);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 2000000, 500, 500,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 588, 588, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500 * 0.5, 500 * 0.5, 288 * 0.5, 288 * 0.5, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
|
@ -540,6 +683,108 @@ TEST(ContentZoomingCalculatorTest, MaxZoomValue) {
|
|||
CheckCropRect(500, 500, 916, 916, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_scale_factor(1.0);
|
||||
options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.025, .025, .95, .95), 0, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(0, 0, -1, -1), 1000000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(0, 0, -1, -1), 2000000, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
// The detection covers 95% of the 1000-px frame, so the first crop is 950; once
// detections stop, the crop opens back up to the full 1000.
|
||||
CheckCropRect(500, 500, 950, 950, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 1000, 1000, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_us_to_first_rect(1000000);
|
||||
options->set_us_to_first_rect_delay(500000);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 1000, 1000, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 470, 470, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 222, 222, 3,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 222, 222, 4,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
|
||||
auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
|
||||
runner->MutableInputs()
|
||||
->Tag("VIDEO_SIZE")
|
||||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner->Outputs().Tag("FIRST_CROP_RECT").packets;
|
||||
ASSERT_EQ(output_packets.size(), 1);
|
||||
const auto& rect = output_packets[0].Get<mediapipe::NormalizedRect>();
|
||||
EXPECT_EQ(rect.x_center(), 0);
|
||||
EXPECT_EQ(rect.y_center(), 0);
|
||||
EXPECT_EQ(rect.width(), 0);
|
||||
EXPECT_EQ(rect.height(), 0);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ProvidesConstantFirstRect) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_us_to_first_rect(500000);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 500000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner->Outputs().Tag("FIRST_CROP_RECT").packets;
|
||||
ASSERT_EQ(output_packets.size(), 4);
|
||||
const auto& first_rect = output_packets[0].Get<mediapipe::NormalizedRect>();
|
||||
EXPECT_NEAR(first_rect.x_center(), 0.5, 0.05);
|
||||
EXPECT_NEAR(first_rect.y_center(), 0.5, 0.05);
|
||||
EXPECT_NEAR(first_rect.width(), 0.222, 0.05);
|
||||
EXPECT_NEAR(first_rect.height(), 0.222, 0.05);
|
||||
for (int i = 1; i < 4; ++i) {
|
||||
const auto& rect = output_packets[i].Get<mediapipe::NormalizedRect>();
|
||||
EXPECT_EQ(first_rect.x_center(), rect.x_center());
|
||||
EXPECT_EQ(first_rect.y_center(), rect.y_center());
|
||||
EXPECT_EQ(first_rect.width(), rect.width());
|
||||
EXPECT_EQ(first_rect.height(), rect.height());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace autoflip
|
||||
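Note: the cache tests above exercise the STATE_CACHE input side packet, which lets ContentZoomingCalculator carry its zoom state across separate runs and resolution changes. Below is a minimal sketch of the same wiring against a full CalculatorGraph instead of a CalculatorRunner; it assumes a graph config that exposes an input side packet named "state_cache" routed to the calculator's STATE_CACHE tag (as the tests arrange via add_input_side_packet), and it omits includes.

absl::Status RunTwiceWithSharedZoomState(
    const mediapipe::CalculatorGraphConfig& config) {
  // Caller-owned cache, handed to every run by pointer, exactly as in the tests.
  mediapipe::autoflip::ContentZoomingCalculatorStateCacheType cache;
  for (int run = 0; run < 2; ++run) {
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    MP_RETURN_IF_ERROR(graph.StartRun(
        {{"state_cache",
          mediapipe::MakePacket<
              mediapipe::autoflip::ContentZoomingCalculatorStateCacheType*>(
              &cache)}}));
    // ... feed VIDEO_SIZE / DETECTIONS packets for this run ...
    MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
    MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  }
  return absl::OkStatus();
}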
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
@ -19,7 +20,6 @@
|
|||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||
@ -46,7 +46,9 @@ proto_library(
|
|||
mediapipe_cc_proto_library(
|
||||
name = "kinematic_path_solver_cc_proto",
|
||||
srcs = ["kinematic_path_solver.proto"],
|
||||
visibility = ["//mediapipe/examples:__subpackages__"],
|
||||
visibility = [
|
||||
"//mediapipe/examples:__subpackages__",
|
||||
],
|
||||
deps = [":kinematic_path_solver_proto"],
|
||||
)
|
||||
|
||||
|
@ -96,11 +98,11 @@ cc_library(
|
|||
deps = [
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -249,10 +251,10 @@ cc_test(
|
|||
":scene_camera_motion_analyzer",
|
||||
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -280,13 +282,13 @@ cc_test(
|
|||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -14,11 +14,11 @@
|
|||
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h"
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -28,8 +28,9 @@
|
|||
#include "mediapipe/framework/port/status_builder.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
DEFINE_string(input_image, "", "The path to an input image.");
|
||||
DEFINE_string(output_folder, "", "The folder to output test result images.");
|
||||
ABSL_FLAG(std::string, input_image, "", "The path to an input image.");
|
||||
ABSL_FLAG(std::string, output_folder, "",
|
||||
"The folder to output test result images.");
|
||||
|
||||
namespace mediapipe {
|
||||
namespace autoflip {
|
||||
@ -19,12 +19,12 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h"
|
||||
#include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
@ -202,7 +202,7 @@ absl::Status DrawFocusPointAndCropWindow(
|
|||
const auto& point = focus_point_frames[i].point(j);
|
||||
const int x = point.norm_point_x() * scene_frame.cols;
|
||||
const int y = point.norm_point_y() * scene_frame.rows;
|
||||
cv::circle(viz_mat, cv::Point(x, y), 3, kRed, CV_FILLED);
|
||||
cv::circle(viz_mat, cv::Point(x, y), 3, kRed, cv::FILLED);
|
||||
center_x += x;
|
||||
center_y += y;
|
||||
}
|
||||
@ -15,10 +15,11 @@
|
|||
// An example of sending OpenCV webcam frames into a MediaPipe graph.
|
||||
#include <cstdlib>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/opencv_highgui_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
|
@ -30,15 +31,14 @@ constexpr char kInputStream[] = "input_video";
|
|||
constexpr char kOutputStream[] = "output_video";
|
||||
constexpr char kWindowName[] = "MediaPipe";
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
DEFINE_string(input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
DEFINE_string(output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
ABSL_FLAG(std::string, output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
|
||||
absl::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
|
@ -148,7 +148,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
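Note: the files above swap gflags-style DEFINE_* declarations and gflags::ParseCommandLineFlags for Abseil's flag API; the same migration repeats in the mains that follow. A standalone sketch of the resulting pattern, with an illustrative flag name (example_path) that is not taken from this commit:

#include <string>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"

// Declares --example_path plus the generated FLAGS_example_path accessor object.
ABSL_FLAG(std::string, example_path, "", "Path to an input file.");

int main(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);
  // Reads are explicit through absl::GetFlag rather than a mutable FLAGS_* global.
  const std::string path = absl::GetFlag(FLAGS_example_path);
  return path.empty() ? 1 : 0;
}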
@ -16,10 +16,11 @@
|
|||
// This example requires a linux computer and a GPU with EGL support drivers.
|
||||
#include <cstdlib>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/opencv_highgui_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
|
@ -34,15 +35,14 @@ constexpr char kInputStream[] = "input_video";
|
|||
constexpr char kOutputStream[] = "output_video";
|
||||
constexpr char kWindowName[] = "MediaPipe";
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
DEFINE_string(input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
DEFINE_string(output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, input_video_path, "",
|
||||
"Full path of video to load. "
|
||||
"If not provided, attempt to use a webcam.");
|
||||
ABSL_FLAG(std::string, output_video_path, "",
|
||||
"Full path of where to save result (.mp4 only). "
|
||||
"If not provided, show result in a window.");
|
||||
|
||||
absl::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
|
@ -191,7 +191,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
@ -23,7 +23,6 @@ cc_binary(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:opencv_highgui",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
|
@ -31,6 +30,8 @@ cc_binary(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/graphs/iris_tracking:iris_depth_cpu_deps",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
],
|
||||
)
|
||||
|
||||
@ -17,11 +17,12 @@
|
|||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/opencv_highgui_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
|
@ -38,12 +39,12 @@ constexpr char kCalculatorGraphConfigFile[] =
|
|||
"mediapipe/graphs/iris_tracking/iris_depth_cpu.pbtxt";
|
||||
constexpr float kMicrosPerSecond = 1e6;
|
||||
|
||||
DEFINE_string(input_image_path, "",
|
||||
"Full path of image to load. "
|
||||
"If not provided, nothing will run.");
|
||||
DEFINE_string(output_image_path, "",
|
||||
"Full path of where to save image result (.jpg only). "
|
||||
"If not provided, show result in a window.");
|
||||
ABSL_FLAG(std::string, input_image_path, "",
|
||||
"Full path of image to load. "
|
||||
"If not provided, nothing will run.");
|
||||
ABSL_FLAG(std::string, output_image_path, "",
|
||||
"Full path of where to save image result (.jpg only). "
|
||||
"If not provided, show result in a window.");
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -148,7 +149,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
@ -21,11 +21,12 @@ cc_library(
|
|||
srcs = ["run_graph_file_io_main.cc"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:map_util",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -17,26 +17,26 @@
|
|||
// to disk.
|
||||
#include <cstdlib>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/map_util.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
DEFINE_string(input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"and corresponding file paths for the CalculatorGraph. The side "
|
||||
"packets are read from the files and fed to the graph as strings "
|
||||
"even if they represent doubles, floats, etc.");
|
||||
DEFINE_string(output_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying the output "
|
||||
"side packets and paths to write to disk for the "
|
||||
"CalculatorGraph.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"and corresponding file paths for the CalculatorGraph. The side "
|
||||
"packets are read from the files and fed to the graph as strings "
|
||||
"even if they represent doubles, floats, etc.");
|
||||
ABSL_FLAG(std::string, output_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying the output "
|
||||
"side packets and paths to write to disk for the "
|
||||
"CalculatorGraph.");
|
||||
|
||||
absl::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
|
@ -85,7 +85,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
@ -20,7 +20,7 @@ package(default_visibility = ["//mediapipe/examples:__subpackages__"])
|
|||
# To run 3D object detection for shoes,
|
||||
# bazel-bin/mediapipe/examples/desktop/object_detection_3d/objectron_cpu \
|
||||
# --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \
|
||||
# --input_side_packets="input_video_path=<input_video_path>,box_landmark_model_path=mediapipe/models/object_detection_3d_sneakers.tflite,output_video_path=<output_video_path>,allowed_labels=Footwear"
|
||||
# --input_side_packets="input_video_path=<input_video_path>,box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_sneakers.tflite,output_video_path=<output_video_path>,allowed_labels=Footwear"
|
||||
# To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly.
|
||||
# Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair
|
||||
# Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera
|
||||
@ -20,11 +20,12 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/map_util.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
@ -32,31 +33,30 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
|
||||
DEFINE_string(input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"for the CalculatorGraph. All values will be treated as the "
|
||||
"string type even if they represent doubles, floats, etc.");
|
||||
ABSL_FLAG(std::string, input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"for the CalculatorGraph. All values will be treated as the "
|
||||
"string type even if they represent doubles, floats, etc.");
|
||||
|
||||
// Local file output flags.
|
||||
// Output stream
|
||||
DEFINE_string(output_stream, "",
|
||||
"The output stream to output to the local file in csv format.");
|
||||
DEFINE_string(output_stream_file, "",
|
||||
"The name of the local file to output all packets sent to "
|
||||
"the stream specified with --output_stream. ");
|
||||
DEFINE_bool(strip_timestamps, false,
|
||||
"If true, only the packet contents (without timestamps) will be "
|
||||
"written into the local file.");
|
||||
ABSL_FLAG(std::string, output_stream, "",
|
||||
"The output stream to output to the local file in csv format.");
|
||||
ABSL_FLAG(std::string, output_stream_file, "",
|
||||
"The name of the local file to output all packets sent to "
|
||||
"the stream specified with --output_stream. ");
|
||||
ABSL_FLAG(bool, strip_timestamps, false,
|
||||
"If true, only the packet contents (without timestamps) will be "
|
||||
"written into the local file.");
|
||||
// Output side packets
|
||||
DEFINE_string(output_side_packets, "",
|
||||
"A CSV of output side packets to output to local file.");
|
||||
DEFINE_string(output_side_packets_file, "",
|
||||
"The name of the local file to output all side packets specified "
|
||||
"with --output_side_packets. ");
|
||||
ABSL_FLAG(std::string, output_side_packets, "",
|
||||
"A CSV of output side packets to output to local file.");
|
||||
ABSL_FLAG(std::string, output_side_packets_file, "",
|
||||
"The name of the local file to output all side packets specified "
|
||||
"with --output_side_packets. ");
|
||||
|
||||
absl::Status OutputStreamToLocalFile(mediapipe::OutputStreamPoller& poller) {
|
||||
std::ofstream file;
|
||||
|
@ -143,7 +143,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
@ -18,10 +18,11 @@ cc_binary(
|
|||
name = "extract_yt8m_features",
|
||||
srcs = ["extract_yt8m_features.cc"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:commandlineflags",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:map_util",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
@ -17,27 +17,27 @@
|
|||
// to disk.
|
||||
#include <cstdlib>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/map_util.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
DEFINE_string(input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"and corresponding file paths for the CalculatorGraph. The side "
|
||||
"packets are read from the files and fed to the graph as strings "
|
||||
"even if they represent doubles, floats, etc.");
|
||||
DEFINE_string(output_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying the output "
|
||||
"side packets and paths to write to disk for the "
|
||||
"CalculatorGraph.");
|
||||
ABSL_FLAG(std::string, calculator_graph_config_file, "",
|
||||
"Name of file containing text format CalculatorGraphConfig proto.");
|
||||
ABSL_FLAG(std::string, input_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying side packets "
|
||||
"and corresponding file paths for the CalculatorGraph. The side "
|
||||
"packets are read from the files and fed to the graph as strings "
|
||||
"even if they represent doubles, floats, etc.");
|
||||
ABSL_FLAG(std::string, output_side_packets, "",
|
||||
"Comma-separated list of key=value pairs specifying the output "
|
||||
"side packets and paths to write to disk for the "
|
||||
"CalculatorGraph.");
|
||||
|
||||
absl::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
|
@ -126,7 +126,7 @@ absl::Status RunMPPGraph() {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
absl::Status run_status = RunMPPGraph();
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
|
||||
@ -23,7 +23,6 @@ package(default_visibility = ["//visibility:private"])
|
|||
package_group(
|
||||
name = "mediapipe_internal",
|
||||
packages = [
|
||||
"//java/com/google/mediapipe/framework/...",
|
||||
"//mediapipe/...",
|
||||
],
|
||||
)
|
||||
|
@ -78,21 +77,19 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "mediapipe_options_proto",
|
||||
srcs = ["mediapipe_options.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "packet_factory_proto",
|
||||
srcs = ["packet_factory.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "packet_generator_proto",
|
||||
srcs = ["packet_generator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
],
|
||||
visibility = [":mediapipe_internal"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
|
@ -105,7 +102,7 @@ mediapipe_proto_library(
|
|||
mediapipe_proto_library(
|
||||
name = "status_handler_proto",
|
||||
srcs = ["status_handler.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = ["//mediapipe/framework:mediapipe_options_proto"],
|
||||
)
|
||||
|
||||
|
@ -274,14 +271,17 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":calculator_base",
|
||||
":calculator_node",
|
||||
":counter_factory",
|
||||
":delegating_executor",
|
||||
":mediapipe_profiling",
|
||||
":executor",
|
||||
":graph_output_stream",
|
||||
":graph_service",
|
||||
":graph_service_manager",
|
||||
":input_stream_manager",
|
||||
":input_stream_shard",
|
||||
":graph_service",
|
||||
":output_side_packet_impl",
|
||||
":output_stream",
|
||||
":output_stream_manager",
|
||||
":output_stream_poller",
|
||||
|
@ -303,29 +303,27 @@ cc_library(
|
|||
"//mediapipe/framework:packet_generator_cc_proto",
|
||||
"//mediapipe/framework:status_handler_cc_proto",
|
||||
"//mediapipe/framework:thread_pool_executor_cc_proto",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"//mediapipe/gpu:graph_support",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
":calculator_node",
|
||||
":output_side_packet_impl",
|
||||
"//mediapipe/framework/profiler:graph_profiler",
|
||||
"//mediapipe/framework/tool:fill_packet_set",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/framework/tool:tag_map",
|
||||
"//mediapipe/framework/tool:validate",
|
||||
"//mediapipe/framework/tool:validate_name",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:source_location",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/profiler:graph_profiler",
|
||||
"//mediapipe/framework/tool:fill_packet_set",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/framework/tool:tag_map",
|
||||
"//mediapipe/framework/tool:validate",
|
||||
"//mediapipe/framework/tool:validate_name",
|
||||
"//mediapipe/gpu:graph_support",
|
||||
"//mediapipe/util:cpu_util",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
|
@ -336,6 +334,28 @@ cc_library(
|
|||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_service_manager",
|
||||
srcs = ["graph_service_manager.cc"],
|
||||
hdrs = ["graph_service_manager.h"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = [
|
||||
":graph_service",
|
||||
"//mediapipe/framework:packet",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "graph_service_manager_test",
|
||||
srcs = ["graph_service_manager_test.cc"],
|
||||
deps = [
|
||||
":graph_service_manager",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "calculator_node",
|
||||
srcs = ["calculator_node.cc"],
|
||||
|
@ -425,6 +445,7 @@ cc_library(
|
|||
":counter",
|
||||
":counter_factory",
|
||||
":graph_service",
|
||||
":graph_service_manager",
|
||||
":input_stream",
|
||||
":output_stream",
|
||||
":packet",
|
||||
|
@ -977,6 +998,8 @@ cc_library(
|
|||
hdrs = ["subgraph.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_service",
|
||||
":graph_service_manager",
|
||||
":port",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:mediapipe_options_cc_proto",
|
||||
|
@ -989,6 +1012,8 @@ cc_library(
|
|||
"//mediapipe/framework/tool:template_expander",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1008,7 +1033,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@eigen_archive//:eigen",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1102,6 +1127,7 @@ cc_library(
|
|||
deps = [
|
||||
":calculator_base",
|
||||
":calculator_contract",
|
||||
":graph_service_manager",
|
||||
":legacy_calculator_support",
|
||||
":packet",
|
||||
":packet_generator",
|
||||
|
@ -1136,6 +1162,24 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "validated_graph_config_test",
|
||||
srcs = ["validated_graph_config_test.cc"],
|
||||
deps = [
|
||||
":calculator_framework",
|
||||
":graph_service",
|
||||
":graph_service_manager",
|
||||
":validated_graph_config",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_validation",
|
||||
hdrs = ["graph_validation.h"],
|
||||
|
@ -1591,13 +1635,16 @@ cc_test(
|
|||
srcs = ["subgraph_test.cc"],
|
||||
deps = [
|
||||
":calculator_framework",
|
||||
":graph_service_manager",
|
||||
":subgraph",
|
||||
":test_calculators",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:sink",
|
||||
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
@ -41,9 +41,9 @@ Counter* CalculatorContext::GetCounter(const std::string& name) {
|
|||
return calculator_state_->GetCounter(name);
|
||||
}
|
||||
|
||||
CounterSet* CalculatorContext::GetCounterSet() {
|
||||
CounterFactory* CalculatorContext::GetCounterFactory() {
|
||||
CHECK(calculator_state_);
|
||||
return calculator_state_->GetCounterSet();
|
||||
return calculator_state_->GetCounterFactory();
|
||||
}
|
||||
|
||||
const PacketSet& CalculatorContext::InputSidePackets() const {
|
||||
@ -76,7 +76,7 @@ class CalculatorContext {
|
|||
|
||||
// Returns the counter set, which can be used to create new counters.
|
||||
// No prefix is added to counters created in this way.
|
||||
CounterSet* GetCounterSet();
|
||||
CounterFactory* GetCounterFactory();
|
||||
|
||||
// Returns the current input timestamp, or Timestamp::Unset if there are
|
||||
// no input packets.
|
||||
|
@ -113,26 +113,9 @@ class CalculatorContext {
|
|||
return calculator_state_->GetSharedProfilingContext().get();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class ServiceBinding {
|
||||
public:
|
||||
bool IsAvailable() {
|
||||
return calculator_state_->IsServiceAvailable(service_);
|
||||
}
|
||||
T& GetObject() { return calculator_state_->GetServiceObject(service_); }
|
||||
|
||||
ServiceBinding(CalculatorState* calculator_state,
|
||||
const GraphService<T>& service)
|
||||
: calculator_state_(calculator_state), service_(service) {}
|
||||
|
||||
private:
|
||||
CalculatorState* calculator_state_;
|
||||
const GraphService<T>& service_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
ServiceBinding<T> Service(const GraphService<T>& service) {
|
||||
return ServiceBinding<T>(calculator_state_, service);
|
||||
return ServiceBinding<T>(calculator_state_->GetServiceObject(service));
|
||||
}
|
||||
|
||||
private:
|
||||
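Note: after the rename above, ad-hoc counters are obtained from the counter factory rather than a counter set, and (per the existing comment) no node-name prefix is added to them. A sketch of the calling convention from inside a calculator; MyCalculator is illustrative and the surrounding class and registration boilerplate is omitted.

absl::Status MyCalculator::Process(mediapipe::CalculatorContext* cc) {
  // Counters created through the factory carry no automatic node-name prefix.
  cc->GetCounterFactory()->GetCounter("MyCalculator-frames")->Increment();
  return absl::OkStatus();
}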
@ -36,6 +36,7 @@
|
|||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/counter_factory.h"
|
||||
#include "mediapipe/framework/delegating_executor.h"
|
||||
#include "mediapipe/framework/graph_service_manager.h"
|
||||
#include "mediapipe/framework/input_stream_manager.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
#include "mediapipe/framework/packet_generator.h"
|
||||
|
@ -392,7 +393,8 @@ absl::Status CalculatorGraph::Initialize(
|
|||
const CalculatorGraphConfig& input_config,
|
||||
const std::map<std::string, Packet>& side_packets) {
|
||||
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(input_config));
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
||||
input_config, /*graph_registry=*/nullptr, &service_manager_));
|
||||
return Initialize(std::move(validated_graph), side_packets);
|
||||
}
|
||||
|
||||
|
@ -402,8 +404,8 @@ absl::Status CalculatorGraph::Initialize(
|
|||
const std::map<std::string, Packet>& side_packets,
|
||||
const std::string& graph_type, const Subgraph::SubgraphOptions* options) {
|
||||
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(input_configs, input_templates,
|
||||
graph_type, options));
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
||||
input_configs, input_templates, graph_type, options, &service_manager_));
|
||||
return Initialize(std::move(validated_graph), side_packets);
|
||||
}
|
||||
|
||||
|
@ -509,19 +511,15 @@ absl::Status CalculatorGraph::StartRun(
|
|||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::Status CalculatorGraph::SetGpuResources(
|
||||
std::shared_ptr<::mediapipe::GpuResources> resources) {
|
||||
RET_CHECK(!ContainsKey(service_packets_, kGpuService.key))
|
||||
auto gpu_service = service_manager_.GetServiceObject(kGpuService);
|
||||
RET_CHECK_EQ(gpu_service, nullptr)
|
||||
<< "The GPU resources have already been configured.";
|
||||
service_packets_[kGpuService.key] =
|
||||
MakePacket<std::shared_ptr<::mediapipe::GpuResources>>(
|
||||
std::move(resources));
|
||||
return absl::OkStatus();
|
||||
return service_manager_.SetServiceObject(kGpuService, std::move(resources));
|
||||
}
|
||||
|
||||
std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
|
||||
const {
|
||||
auto service_iter = service_packets_.find(kGpuService.key);
|
||||
if (service_iter == service_packets_.end()) return nullptr;
|
||||
return service_iter->second.Get<std::shared_ptr<::mediapipe::GpuResources>>();
|
||||
return service_manager_.GetServiceObject(kGpuService);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
||||
|
@ -536,8 +534,7 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
|||
}
|
||||
}
|
||||
if (uses_gpu) {
|
||||
auto service_iter = service_packets_.find(kGpuService.key);
|
||||
bool has_service = service_iter != service_packets_.end();
|
||||
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
|
||||
|
||||
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
|
||||
// Workaround for b/116875321: CalculatorRunner provides an empty packet,
|
||||
|
@ -545,15 +542,12 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
|||
bool has_legacy_sp = legacy_sp_iter != side_packets.end() &&
|
||||
!legacy_sp_iter->second.IsEmpty();
|
||||
|
||||
std::shared_ptr<::mediapipe::GpuResources> gpu_resources;
|
||||
if (has_service) {
|
||||
if (gpu_resources) {
|
||||
if (has_legacy_sp) {
|
||||
LOG(WARNING)
|
||||
<< "::mediapipe::GpuSharedData provided as a side packet while the "
|
||||
<< "graph already had one; ignoring side packet";
|
||||
}
|
||||
gpu_resources = service_iter->second
|
||||
.Get<std::shared_ptr<::mediapipe::GpuResources>>();
|
||||
update_sp = true;
|
||||
} else {
|
||||
if (has_legacy_sp) {
|
||||
|
@ -564,8 +558,8 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
|||
ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
|
||||
update_sp = true;
|
||||
}
|
||||
service_packets_[kGpuService.key] =
|
||||
MakePacket<std::shared_ptr<::mediapipe::GpuResources>>(gpu_resources);
|
||||
MP_RETURN_IF_ERROR(
|
||||
service_manager_.SetServiceObject(kGpuService, gpu_resources));
|
||||
}
|
||||
|
||||
// Create or replace the legacy side packet if needed.
|
||||
|
@ -682,8 +676,10 @@ absl::Status CalculatorGraph::PrepareForRun(
|
|||
std::placeholders::_1, std::placeholders::_2);
|
||||
node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback);
|
||||
scheduler_.AssignNodeToSchedulerQueue(&node);
|
||||
// TODO: update calculator node to use GraphServiceManager
|
||||
// instead of service packets?
|
||||
const absl::Status result = node.PrepareForRun(
|
||||
current_run_side_packets_, service_packets_,
|
||||
current_run_side_packets_, service_manager_.ServicePackets(),
|
||||
std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_,
|
||||
&node),
|
||||
std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_,
|
||||
|
@ -811,6 +807,11 @@ absl::Status CalculatorGraph::AddPacketToInputStreamInternal(
|
|||
CHECK_GE(node_id, validated_graph_->CalculatorInfos().size());
|
||||
{
|
||||
absl::MutexLock lock(&full_input_streams_mutex_);
|
||||
if (full_input_streams_.empty()) {
|
||||
return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "CalculatorGraph::AddPacketToInputStream() is called before "
|
||||
"StartRun()";
|
||||
}
|
||||
if (graph_input_stream_add_mode_ ==
|
||||
GraphInputStreamAddMode::ADD_IF_NOT_FULL) {
|
||||
if (has_error_) {
|
||||
|
@ -1170,21 +1171,6 @@ void CalculatorGraph::Pause() { scheduler_.Pause(); }
|
|||
|
||||
void CalculatorGraph::Resume() { scheduler_.Resume(); }
|
||||
|
||||
absl::Status CalculatorGraph::SetServicePacket(const GraphServiceBase& service,
|
||||
Packet p) {
|
||||
// TODO: check that the graph has not been started!
|
||||
service_packets_[service.key] = std::move(p);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) {
|
||||
auto it = service_packets_.find(service.key);
|
||||
if (it == service_packets_.end()) {
|
||||
return {};
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::SetExecutorInternal(
|
||||
const std::string& name, std::shared_ptr<Executor> executor) {
|
||||
if (!executors_.emplace(name, executor).second) {
|
||||
@ -38,6 +38,7 @@
|
|||
#include "mediapipe/framework/executor.h"
|
||||
#include "mediapipe/framework/graph_output_stream.h"
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
#include "mediapipe/framework/graph_service_manager.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
#include "mediapipe/framework/output_side_packet_impl.h"
|
||||
#include "mediapipe/framework/output_stream.h"
|
||||
|
@ -377,19 +378,20 @@ class CalculatorGraph {
|
|||
template <typename T>
|
||||
absl::Status SetServiceObject(const GraphService<T>& service,
|
||||
std::shared_ptr<T> object) {
|
||||
return SetServicePacket(service,
|
||||
MakePacket<std::shared_ptr<T>>(std::move(object)));
|
||||
// TODO: check that the graph has not been started!
|
||||
return service_manager_.SetServiceObject(service, object);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) {
|
||||
Packet p = GetServicePacket(service);
|
||||
if (p.IsEmpty()) return nullptr;
|
||||
return p.Get<std::shared_ptr<T>>();
|
||||
return service_manager_.GetServiceObject(service);
|
||||
}
|
||||
|
||||
// Only the Java API should call this directly.
|
||||
absl::Status SetServicePacket(const GraphServiceBase& service, Packet p);
|
||||
absl::Status SetServicePacket(const GraphServiceBase& service, Packet p) {
|
||||
// TODO: check that the graph has not been started!
|
||||
return service_manager_.SetServicePacket(service, p);
|
||||
}
|
||||
|
||||
private:
|
||||
// GraphRunState is used as a parameter in the function CallStatusHandlers.
|
||||
|
@ -523,7 +525,6 @@ class CalculatorGraph {
|
|||
// status before taking any action.
|
||||
void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
|
||||
|
||||
Packet GetServicePacket(const GraphServiceBase& service);
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
// Owns the legacy GpuSharedData if we need to create one for backwards
|
||||
// compatibility.
|
||||
|
@ -598,7 +599,8 @@ class CalculatorGraph {
|
|||
// The processed input side packet map for this run.
|
||||
std::map<std::string, Packet> current_run_side_packets_;
|
||||
|
||||
std::map<std::string, Packet> service_packets_;
|
||||
// Object to manage graph services.
|
||||
GraphServiceManager service_manager_;
|
||||
|
||||
// Vector of errors encountered while running graph. Always use RecordError()
|
||||
// to add an error to this vector.
|
||||
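Note: with the header changes above, CalculatorGraph::SetServiceObject and GetServiceObject are thin wrappers over the new GraphServiceManager. A hedged sketch of the caller-side usage; MyService and kMyService are illustrative, and the string-key constructor of GraphService<T> is assumed from how kGpuService is declared elsewhere in the framework.

#include <memory>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/port/ret_check.h"

struct MyService {
  int value = 0;
};

// Assumed: GraphService<T> is constructible from a string key, like kGpuService.
const mediapipe::GraphService<MyService> kMyService("my_service");

absl::Status ConfigureGraph(mediapipe::CalculatorGraph& graph) {
  // Must happen before StartRun(); the object is stored in the GraphServiceManager.
  MP_RETURN_IF_ERROR(
      graph.SetServiceObject(kMyService, std::make_shared<MyService>()));
  // Lookups go through the same manager and return nullptr when unset.
  std::shared_ptr<MyService> service = graph.GetServiceObject(kMyService);
  RET_CHECK(service != nullptr) << "service was not registered";
  return absl::OkStatus();
}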
@ -1361,6 +1361,38 @@ TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) {
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(CalculatorGraphBoundsTest, PostStreamPacketToSetProcessTimestampBound) {
|
||||
std::string config_str = R"(
|
||||
input_stream: "input_0"
|
||||
node {
|
||||
calculator: "ProcessBoundToPacketCalculator"
|
||||
input_stream: "input_0"
|
||||
output_stream: "output_0"
|
||||
}
|
||||
)";
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||
CalculatorGraph graph;
|
||||
std::vector<Packet> output_0_packets;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
|
||||
output_0_packets.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_0", MakePacket<int>(0).At(Timestamp::PostStream())));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_EQ(output_0_packets.size(), 1);
|
||||
EXPECT_EQ(output_0_packets[0].Timestamp(), Timestamp::PostStream());
|
||||
|
||||
// Shutdown the graph.
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// A Calculator that sends a timestamp bound for every other input.
|
||||
class OccasionalBoundCalculator : public CalculatorBase {
|
||||
public:
|
||||
@ -4356,256 +4356,5 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) {
|
|||
ASSERT_EQ(5, packet_dump.size());
|
||||
}
|
||||
|
||||
// Returns the first packet of the input stream.
|
||||
class FirstPacketFilterCalculator : public CalculatorBase {
|
||||
public:
|
||||
FirstPacketFilterCalculator() {}
|
||||
~FirstPacketFilterCalculator() override {}
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (!seen_first_packet_) {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
cc->Outputs().Index(0).Close();
|
||||
seen_first_packet_ = true;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
bool seen_first_packet_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
|
||||
constexpr int kDefaultMaxCount = 1000;
|
||||
|
||||
TEST(CalculatorGraph, TestPollPacket) {
|
||||
CalculatorGraphConfig config;
|
||||
CalculatorGraphConfig::Node* node = config.add_node();
|
||||
node->set_calculator("CountingSourceCalculator");
|
||||
node->add_output_stream("output");
|
||||
node->add_input_side_packet("MAX_COUNT:max_count");
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
auto status_or_poller = graph.AddOutputStreamPoller("output");
|
||||
ASSERT_TRUE(status_or_poller.ok());
|
||||
OutputStreamPoller poller = std::move(status_or_poller.value());
|
||||
MP_ASSERT_OK(
|
||||
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
|
||||
Packet packet;
|
||||
int num_packets = 0;
|
||||
while (poller.Next(&packet)) {
|
||||
EXPECT_EQ(num_packets, packet.Get<int>());
|
||||
++num_packets;
|
||||
}
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_FALSE(poller.Next(&packet));
|
||||
EXPECT_EQ(kDefaultMaxCount, num_packets);
|
||||
}
|
||||
|
||||
TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) {
|
||||
CalculatorGraphConfig config;
|
||||
CalculatorGraphConfig::Node* node = config.add_node();
|
||||
node->set_calculator("CountingSourceCalculator");
|
||||
node->add_output_stream("output");
|
||||
node->add_input_side_packet("MAX_COUNT:max_count");
|
||||
|
||||
for (int queue_size = 1; queue_size < 10; ++queue_size) {
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
auto status_or_poller = graph.AddOutputStreamPoller("output");
|
||||
ASSERT_TRUE(status_or_poller.ok());
|
||||
OutputStreamPoller poller = std::move(status_or_poller.value());
|
||||
poller.SetMaxQueueSize(queue_size);
|
||||
MP_ASSERT_OK(
|
||||
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
|
||||
Packet packet;
|
||||
int num_packets = 0;
|
||||
while (poller.Next(&packet)) {
|
||||
EXPECT_EQ(num_packets, packet.Get<int>());
|
||||
++num_packets;
|
||||
}
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_FALSE(poller.Next(&packet));
|
||||
EXPECT_EQ(kDefaultMaxCount, num_packets);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) {
|
||||
CalculatorGraphConfig config;
|
||||
CalculatorGraphConfig::Node* node1 = config.add_node();
|
||||
node1->set_calculator("CountingSourceCalculator");
|
||||
node1->add_output_stream("stream1");
|
||||
node1->add_input_side_packet("MAX_COUNT:max_count");
|
||||
CalculatorGraphConfig::Node* node2 = config.add_node();
|
||||
node2->set_calculator("PassThroughCalculator");
|
||||
node2->add_input_stream("stream1");
|
||||
node2->add_output_stream("stream2");
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
auto status_or_poller1 = graph.AddOutputStreamPoller("stream1");
|
||||
ASSERT_TRUE(status_or_poller1.ok());
|
||||
OutputStreamPoller poller1 = std::move(status_or_poller1.value());
|
||||
auto status_or_poller2 = graph.AddOutputStreamPoller("stream2");
|
||||
ASSERT_TRUE(status_or_poller2.ok());
|
||||
OutputStreamPoller poller2 = std::move(status_or_poller2.value());
|
||||
MP_ASSERT_OK(
|
||||
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
|
||||
Packet packet1;
|
||||
Packet packet2;
|
||||
int num_packets1 = 0;
|
||||
int num_packets2 = 0;
|
||||
int running_pollers = 2;
|
||||
while (running_pollers > 0) {
|
||||
if (poller1.Next(&packet1)) {
|
||||
EXPECT_EQ(num_packets1++, packet1.Get<int>());
|
||||
} else {
|
||||
--running_pollers;
|
||||
}
|
||||
if (poller2.Next(&packet2)) {
|
||||
EXPECT_EQ(num_packets2++, packet2.Get<int>());
|
||||
} else {
|
||||
--running_pollers;
|
||||
}
|
||||
}
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_FALSE(poller1.Next(&packet1));
|
||||
EXPECT_FALSE(poller2.Next(&packet2));
|
||||
EXPECT_EQ(kDefaultMaxCount, num_packets1);
|
||||
EXPECT_EQ(kDefaultMaxCount, num_packets2);
|
||||
}
|
||||
|
||||
// Ensure that when a custom input stream handler is used to handle packets from
|
||||
// input streams, an error message is outputted with the appropriate link to
|
||||
// resolve the issue when the calculator doesn't handle inputs in monotonically
|
||||
// increasing order of timestamps.
|
||||
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) {
|
||||
CalculatorGraph graph;
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: 'input0'
|
||||
input_stream: 'input1'
|
||||
node {
|
||||
calculator: 'SimpleMuxCalculator'
|
||||
input_stream: 'input0'
|
||||
input_stream: 'input1'
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
output_stream: 'output'
|
||||
}
|
||||
)");
|
||||
std::vector<Packet> packet_dump;
|
||||
tool::AddVectorSink("output", &config, &packet_dump);
|
||||
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// Send packets to input stream "input0" at timestamps 0 and 1 consecutively.
|
||||
Timestamp input0_timestamp = Timestamp(0);
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(1).At(input0_timestamp)));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
ASSERT_EQ(1, packet_dump.size());
|
||||
EXPECT_EQ(1, packet_dump[0].Get<int>());
|
||||
|
||||
++input0_timestamp;
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(3).At(input0_timestamp)));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
ASSERT_EQ(2, packet_dump.size());
|
||||
EXPECT_EQ(3, packet_dump[1].Get<int>());
|
||||
|
||||
// Send a packet to input stream "input1" at timestamp 0 after sending two
|
||||
// packets at timestamps 0 and 1 to input stream "input0". This will result
|
||||
// in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle
|
||||
// inputs from all streams in monotonically increasing order of timestamps.
|
||||
Timestamp input1_timestamp = Timestamp(0);
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(2).At(input1_timestamp)));
|
||||
absl::Status run_status = graph.WaitUntilIdle();
|
||||
EXPECT_THAT(
|
||||
run_status.ToString(),
|
||||
testing::AllOf(
|
||||
// The core problem.
|
||||
testing::HasSubstr("timestamp mismatch on a calculator"),
|
||||
testing::HasSubstr(
|
||||
"timestamps that are not strictly monotonically increasing"),
|
||||
// Link to the possible solution.
|
||||
testing::HasSubstr("ImmediateInputStreamHandler class comment")));
|
||||
}
|
||||
|
||||
void DoTestMultipleGraphRuns(absl::string_view input_stream_handler,
                             bool select_packet) {
  std::string graph_proto = absl::StrFormat(R"(
    input_stream: 'input'
    input_stream: 'select'
    node {
      calculator: 'PassThroughCalculator'
      input_stream: 'input'
      input_stream: 'select'
      input_stream_handler {
        input_stream_handler: "%s"
      }
      output_stream: 'output'
      output_stream: 'select_out'
    }
  )",
                                            input_stream_handler.data());
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("output", &config, &packet_dump);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));

  struct Run {
    Timestamp timestamp;
    int value;
  };
  std::vector<Run> runs = {{.timestamp = Timestamp(2000), .value = 2},
                           {.timestamp = Timestamp(1000), .value = 1}};
  for (const Run& run : runs) {
    MP_ASSERT_OK(graph.StartRun({}));

    if (select_packet) {
      MP_EXPECT_OK(graph.AddPacketToInputStream(
          "select", MakePacket<int>(0).At(run.timestamp)));
    }
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "input", MakePacket<int>(run.value).At(run.timestamp)));
    MP_ASSERT_OK(graph.WaitUntilIdle());
    ASSERT_EQ(1, packet_dump.size());
    EXPECT_EQ(run.value, packet_dump[0].Get<int>());
    EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp());

    MP_ASSERT_OK(graph.CloseAllPacketSources());
    MP_ASSERT_OK(graph.WaitUntilDone());

    packet_dump.clear();
  }
}

TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) {
  DoTestMultipleGraphRuns("BarrierInputStreamHandler", true);
  DoTestMultipleGraphRuns("DefaultInputStreamHandler", true);
  DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true);
  DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true);
  DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false);
  DoTestMultipleGraphRuns("MuxInputStreamHandler", true);
  DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true);
  DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true);
}

}  // namespace
}  // namespace mediapipe

@ -408,13 +408,13 @@ absl::Status CalculatorNode::PrepareForRun(
      validated_graph_->CalculatorInfos()[node_id_].Contract();
  for (const auto& svc_req : contract.ServiceRequests()) {
    const auto& req = svc_req.second;
    std::string key{req.Service().key};
    auto it = service_packets.find(key);
    auto it = service_packets.find(req.Service().key);
    if (it == service_packets.end()) {
      RET_CHECK(req.IsOptional())
          << "required service '" << key << "' was not provided";
          << "required service '" << req.Service().key << "' was not provided";
    } else {
      calculator_state_->SetServicePacket(key, it->second);
      MP_RETURN_IF_ERROR(
          calculator_state_->SetServicePacket(req.Service(), it->second));
    }
  }

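The hunk above only fails a run when a missing service request is not optional. As a hedged sketch (not part of this commit), this is how a calculator would typically mark such a request optional in its contract; MyCalculator, MyServiceObject and kMyService are hypothetical names, and it is assumed the contract-side ServiceRequest exposes an Optional() setter matching the IsOptional() check above:

// Hypothetical service declaration and calculator contract.
const GraphService<MyServiceObject> kMyService("my_service");

class MyCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Request the service but allow the graph to run without it.
    cc->UseService(kMyService).Optional();
    return absl::OkStatus();
  }
};
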
@ -61,13 +61,9 @@ Counter* CalculatorState::GetCounter(const std::string& name) {
  return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name));
}

CounterSet* CalculatorState::GetCounterSet() {
CounterFactory* CalculatorState::GetCounterFactory() {
  CHECK(counter_factory_);
  return counter_factory_->GetCounterSet();
}

void CalculatorState::SetServicePacket(const std::string& key, Packet packet) {
  service_packets_[key] = std::move(packet);
  return counter_factory_;
}

}  // namespace mediapipe

@ -27,6 +27,7 @@
#include "mediapipe/framework/counter.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port.h"
@ -81,7 +82,7 @@ class CalculatorState {
  // Returns a counter set, which can be passed to other classes, to generate
  // counters. NOTE: This differs from GetCounter, in that the counters
  // created by this counter set do not have the NodeName prefix.
  CounterSet* GetCounterSet();
  CounterFactory* GetCounterFactory();

  std::shared_ptr<ProfilingContext> GetSharedProfilingContext() const {
    return profiling_context_;
@ -99,17 +100,14 @@ class CalculatorState {
    counter_factory_ = counter_factory;
  }

  void SetServicePacket(const std::string& key, Packet packet);

  bool IsServiceAvailable(const GraphServiceBase& service) {
    return ContainsKey(service_packets_, service.key);
  absl::Status SetServicePacket(const GraphServiceBase& service,
                                Packet packet) {
    return graph_service_manager_.SetServicePacket(service, packet);
  }

  template <typename T>
  T& GetServiceObject(const GraphService<T>& service) {
    auto it = service_packets_.find(service.key);
    CHECK(it != service_packets_.end());
    return *it->second.template Get<std::shared_ptr<T>>();
  std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) {
    return graph_service_manager_.GetServiceObject(service);
  }

 private:
@ -129,7 +127,7 @@ class CalculatorState {
  // The graph tracing and profiling interface.
  std::shared_ptr<ProfilingContext> profiling_context_;

  std::map<std::string, Packet> service_packets_;
  GraphServiceManager graph_service_manager_;

  ////////////////////////////////////////
  // Variables which ARE cleared by ResetBetweenRuns().

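With the hunks above, CalculatorState forwards service storage to GraphServiceManager, and GetServiceObject() now returns a shared_ptr instead of a reference, so an absent service yields null rather than a CHECK failure. A minimal sketch of exercising that API; MyService, kMyService and ProvideAndUse are hypothetical names, and only the SetServicePacket/GetServiceObject signatures shown above are relied upon:

const GraphService<MyService> kMyService("my_service");

absl::Status ProvideAndUse(CalculatorState& state) {
  // Service packets hold a shared_ptr to the service object.
  MP_RETURN_IF_ERROR(state.SetServicePacket(
      kMyService,
      MakePacket<std::shared_ptr<MyService>>(std::make_shared<MyService>())));
  std::shared_ptr<MyService> service = state.GetServiceObject(kMyService);
  RET_CHECK(service != nullptr) << "service lookup failed";
  return absl::OkStatus();
}
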
@ -37,7 +37,7 @@ inline StatusBuilder RetCheckImpl(const absl::Status& status,
                                  const char* condition,
                                  mediapipe::source_location location) {
  if (ABSL_PREDICT_TRUE(status.ok()))
    return mediapipe::StatusBuilder(OkStatus(), location);
    return mediapipe::StatusBuilder(absl::OkStatus(), location);
  return RetCheckFailSlowPath(location, condition, status);
}

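The change above only swaps OkStatus() for absl::OkStatus(); RET_CHECK behavior is unchanged. As a brief hedged illustration of how this helper is typically reached (the function name is invented):

absl::Status CheckPositive(int value) {
  // On failure, RET_CHECK routes through RetCheckFailSlowPath() and returns an
  // error status that carries the streamed message.
  RET_CHECK(value > 0) << "expected a positive value, got " << value;
  return absl::OkStatus();
}
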
@ -18,7 +18,7 @@

namespace mediapipe {

std::ostream& operator<<(std::ostream& os, const Status& x) {
std::ostream& operator<<(std::ostream& os, const absl::Status& x) {
  os << x.ToString();
  return os;
}

@ -194,10 +194,10 @@ namespace status_macro_internal {
// that declares a variable.
class StatusAdaptorForMacros {
 public:
  StatusAdaptorForMacros(const Status& status, const char* file, int line)
  StatusAdaptorForMacros(const absl::Status& status, const char* file, int line)
      : builder_(status, file, line) {}

  StatusAdaptorForMacros(Status&& status, const char* file, int line)
  StatusAdaptorForMacros(absl::Status&& status, const char* file, int line)
      : builder_(std::move(status), file, line) {}

  StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */,

@ -79,13 +79,10 @@ def _get_proto_provider(dep):

def _encode_binary_proto_impl(ctx):
    """Implementation of the encode_binary_proto rule."""
    all_protos = depset()
    for dep in ctx.attr.deps:
        provider = _get_proto_provider(dep)
        all_protos = depset(
            direct = [],
            transitive = [all_protos, provider.transitive_sources],
        )
    all_protos = depset(
        direct = [],
        transitive = [_get_proto_provider(dep).transitive_sources for dep in ctx.attr.deps],
    )

    textpb = ctx.file.input
    binarypb = ctx.outputs.output or ctx.actions.declare_file(
@ -120,7 +117,7 @@ def _encode_binary_proto_impl(ctx):
        data_runfiles = ctx.runfiles(transitive_files = output_depset),
    )]

encode_binary_proto = rule(
_encode_binary_proto = rule(
    implementation = _encode_binary_proto_impl,
    attrs = {
        "_proto_compiler": attr.label(
@ -142,6 +139,15 @@ encode_binary_proto = rule(
    },
)

def encode_binary_proto(name, input, message_type, deps, **kwargs):
    _encode_binary_proto(
        name = name,
        input = input,
        message_type = message_type,
        deps = deps,
        **kwargs
    )

def _generate_proto_descriptor_set_impl(ctx):
    """Implementation of the generate_proto_descriptor_set rule."""
    all_protos = depset(transitive = [
@ -114,7 +114,7 @@ cc_library(
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@eigen_archive//:eigen",
        "@eigen_archive//:eigen3",
    ],
)

@ -260,9 +260,11 @@ mediapipe_register_type(
    include_headers = ["mediapipe/framework/formats/landmark.pb.h"],
    types = [
        "::mediapipe::Landmark",
        "::mediapipe::LandmarkList",
        "::mediapipe::NormalizedLandmark",
        "::mediapipe::NormalizedLandmarkList",
        "::std::vector<::mediapipe::Landmark>",
        "::std::vector<::mediapipe::LandmarkList>",
        "::std::vector<::mediapipe::NormalizedLandmark>",
        "::std::vector<::mediapipe::NormalizedLandmarkList>",
    ],

@ -31,6 +31,8 @@ message Classification {
  optional float score = 2;
  // Label or name of the class.
  optional string label = 3;
  // Optional human-readable string for display purposes.
  optional string display_name = 4;
}

// Group of Classification protos.

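The hunk above adds a display_name field to the Classification proto. A minimal C++ sketch of populating it, with arbitrary example values:

mediapipe::Classification classification;
classification.set_score(0.9f);
classification.set_label("face");
// New field introduced by this hunk.
classification.set_display_name("Face");
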
@ -78,6 +78,12 @@ class Image {
    pixel_mutex_ = std::make_shared<absl::Mutex>();
  }

  // CPU getters.
  const ImageFrameSharedPtr& GetImageFrameSharedPtr() const {
    if (use_gpu_ == true) ConvertToCpu();
    return image_frame_;
  }

  // Creates an Image representing the same image content as the input GPU
  // buffer in platform-specific representations.
#if !MEDIAPIPE_DISABLE_GPU
@ -95,13 +101,8 @@ class Image {
    gpu_buffer_ = gpu_buffer;
    pixel_mutex_ = std::make_shared<absl::Mutex>();
  }
#endif  // !MEDIAPIPE_DISABLE_GPU

  const ImageFrameSharedPtr& GetImageFrameSharedPtr() const {
    if (use_gpu_ == true) ConvertToCpu();
    return image_frame_;
  }
#if !MEDIAPIPE_DISABLE_GPU
  // GPU getters.
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
  CVPixelBufferRef GetCVPixelBufferRef() const {
    if (use_gpu_ == false) ConvertToGpu();

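The reordering above groups the CPU and GPU getters; GetImageFrameSharedPtr() still converts a GPU-backed Image to CPU on demand. A small sketch of CPU-side pixel access, assuming an existing mediapipe::Image named image:

const mediapipe::ImageFrameSharedPtr& frame = image.GetImageFrameSharedPtr();
// Standard ImageFrame accessors; any GPU-to-CPU conversion already happened
// inside the getter above.
const int width = frame->Width();
const int height = frame->Height();
const uint8_t* pixels = frame->PixelData();
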
@ -47,8 +47,8 @@ message LandmarkList {
  repeated Landmark landmark = 1;
}

// A normalized version of above Landmark proto. All coordiates should be within
// [0, 1].
// A normalized version of above Landmark proto. All coordinates should be
// within [0, 1].
message NormalizedLandmark {
  optional float x = 1;
  optional float y = 2;

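Per the corrected comment above, normalized landmark coordinates are expected to lie within [0, 1]. A short C++ sketch with arbitrary values:

mediapipe::NormalizedLandmarkList landmarks;
mediapipe::NormalizedLandmark* landmark = landmarks.add_landmark();
// Normalized image coordinates, so both values stay in [0, 1].
landmark->set_x(0.25f);
landmark->set_y(0.75f);
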
@ -67,11 +67,11 @@ cc_test(
    deps = [
        ":optical_flow_field",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:commandlineflags",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:logging",
        "@com_google_absl//absl/flags:flag",
        "@org_tensorflow//tensorflow/core:framework",
    ],
)

Some files were not shown because too many files have changed in this diff.