Project import generated by Copybara.
GitOrigin-RevId: ff83882955f1a1e2a043ff4e71278be9d7217bbe
|
@ -23,6 +23,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
build-essential \
|
build-essential \
|
||||||
|
gcc-8 g++-8 \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
curl \
|
curl \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
|
@ -44,6 +45,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
apt-get clean && \
|
apt-get clean && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /usr/bin/g++ g++ /usr/bin/g++-8
|
||||||
RUN pip3 install --upgrade setuptools
|
RUN pip3 install --upgrade setuptools
|
||||||
RUN pip3 install wheel
|
RUN pip3 install wheel
|
||||||
RUN pip3 install future
|
RUN pip3 install future
|
||||||
|
|
|
@ -337,6 +337,8 @@ maven_install(
|
||||||
"androidx.test.espresso:espresso-core:3.1.1",
|
"androidx.test.espresso:espresso-core:3.1.1",
|
||||||
"com.github.bumptech.glide:glide:4.11.0",
|
"com.github.bumptech.glide:glide:4.11.0",
|
||||||
"com.google.android.material:material:aar:1.0.0-rc01",
|
"com.google.android.material:material:aar:1.0.0-rc01",
|
||||||
|
"com.google.auto.value:auto-value:1.6.4",
|
||||||
|
"com.google.auto.value:auto-value-annotations:1.6.4",
|
||||||
"com.google.code.findbugs:jsr305:3.0.2",
|
"com.google.code.findbugs:jsr305:3.0.2",
|
||||||
"com.google.flogger:flogger-system-backend:0.3.1",
|
"com.google.flogger:flogger-system-backend:0.3.1",
|
||||||
"com.google.flogger:flogger:0.3.1",
|
"com.google.flogger:flogger:0.3.1",
|
||||||
|
@ -367,9 +369,9 @@ http_archive(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tensorflow repo should always go after the other external dependencies.
|
# Tensorflow repo should always go after the other external dependencies.
|
||||||
# 2021-03-25
|
# 2021-04-30
|
||||||
_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887"
|
_TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d"
|
||||||
_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f"
|
_TENSORFLOW_SHA256= "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0"
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "org_tensorflow",
|
name = "org_tensorflow",
|
||||||
urls = [
|
urls = [
|
||||||
|
|
|
@ -17,15 +17,15 @@
|
||||||
# Script to build/run all MediaPipe desktop example apps (with webcam input).
|
# Script to build/run all MediaPipe desktop example apps (with webcam input).
|
||||||
#
|
#
|
||||||
# To build and run all apps and store them in out_dir:
|
# To build and run all apps and store them in out_dir:
|
||||||
# $ ./build_ios_examples.sh -d out_dir
|
# $ ./build_desktop_examples.sh -d out_dir
|
||||||
# Omitting -d and the associated directory saves all generated apps in the
|
# Omitting -d and the associated directory saves all generated apps in the
|
||||||
# current directory.
|
# current directory.
|
||||||
# To build all apps and store them in out_dir:
|
# To build all apps and store them in out_dir:
|
||||||
# $ ./build_ios_examples.sh -d out_dir -b
|
# $ ./build_desktop_examples.sh -d out_dir -b
|
||||||
# Omitting -d and the associated directory saves all generated apps in the
|
# Omitting -d and the associated directory saves all generated apps in the
|
||||||
# current directory.
|
# current directory.
|
||||||
# To run all apps already stored in out_dir:
|
# To run all apps already stored in out_dir:
|
||||||
# $ ./build_ios_examples.sh -d out_dir -r
|
# $ ./build_desktop_examples.sh -d out_dir -r
|
||||||
# Omitting -d and the associated directory assumes all apps are in the current
|
# Omitting -d and the associated directory assumes all apps are in the current
|
||||||
# directory.
|
# directory.
|
||||||
|
|
||||||
|
|
|
@ -187,7 +187,7 @@ node {
|
||||||
```
|
```
|
||||||
|
|
||||||
In the calculator implementation, inputs and outputs are also identified by tag
|
In the calculator implementation, inputs and outputs are also identified by tag
|
||||||
name and index number. In the function below input are output are identified:
|
name and index number. In the function below input and output are identified:
|
||||||
|
|
||||||
* By index number: The combined input stream is identified simply by index
|
* By index number: The combined input stream is identified simply by index
|
||||||
`0`.
|
`0`.
|
||||||
|
@ -355,7 +355,6 @@ class PacketClonerCalculator : public CalculatorBase {
|
||||||
current_[i].At(cc->InputTimestamp()));
|
current_[i].At(cc->InputTimestamp()));
|
||||||
// Add a packet to output stream of index i a packet from inputstream i
|
// Add a packet to output stream of index i a packet from inputstream i
|
||||||
// with timestamp common to all present inputs
|
// with timestamp common to all present inputs
|
||||||
//
|
|
||||||
} else {
|
} else {
|
||||||
cc->Outputs().Index(i).SetNextTimestampBound(
|
cc->Outputs().Index(i).SetNextTimestampBound(
|
||||||
cc->InputTimestamp().NextAllowedInStream());
|
cc->InputTimestamp().NextAllowedInStream());
|
||||||
|
@ -382,7 +381,7 @@ defined your calculator class, register it with a macro invocation
|
||||||
REGISTER_CALCULATOR(calculator_class_name).
|
REGISTER_CALCULATOR(calculator_class_name).
|
||||||
|
|
||||||
Below is a trivial MediaPipe graph that has 3 input streams, 1 node
|
Below is a trivial MediaPipe graph that has 3 input streams, 1 node
|
||||||
(PacketClonerCalculator) and 3 output streams.
|
(PacketClonerCalculator) and 2 output streams.
|
||||||
|
|
||||||
```proto
|
```proto
|
||||||
input_stream: "room_mic_signal"
|
input_stream: "room_mic_signal"
|
||||||
|
|
|
@ -83,12 +83,12 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
|
||||||
output_stream: "out3"
|
output_stream: "out3"
|
||||||
|
|
||||||
node {
|
node {
|
||||||
calculator: "PassThroughculator"
|
calculator: "PassThroughCalculator"
|
||||||
input_stream: "out1"
|
input_stream: "out1"
|
||||||
output_stream: "out2"
|
output_stream: "out2"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
calculator: "PassThroughculator"
|
calculator: "PassThroughCalculator"
|
||||||
input_stream: "out2"
|
input_stream: "out2"
|
||||||
output_stream: "out3"
|
output_stream: "out3"
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ Please verify all the necessary packages are installed.
|
||||||
* Android SDK Build-Tools 28 or 29
|
* Android SDK Build-Tools 28 or 29
|
||||||
* Android SDK Platform-Tools 28 or 29
|
* Android SDK Platform-Tools 28 or 29
|
||||||
* Android SDK Tools 26.1.1
|
* Android SDK Tools 26.1.1
|
||||||
* Android NDK 17c or above
|
* Android NDK 19c or above
|
||||||
|
|
||||||
### Option 1: Build with Bazel in Command Line
|
### Option 1: Build with Bazel in Command Line
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ app:
|
||||||
* Verify that Android SDK Build-Tools 28 or 29 is installed.
|
* Verify that Android SDK Build-Tools 28 or 29 is installed.
|
||||||
* Verify that Android SDK Platform-Tools 28 or 29 is installed.
|
* Verify that Android SDK Platform-Tools 28 or 29 is installed.
|
||||||
* Verify that Android SDK Tools 26.1.1 is installed.
|
* Verify that Android SDK Tools 26.1.1 is installed.
|
||||||
* Verify that Android NDK 17c or above is installed.
|
* Verify that Android NDK 19c or above is installed.
|
||||||
* Take note of the Android NDK Location, e.g.,
|
* Take note of the Android NDK Location, e.g.,
|
||||||
`/usr/local/home/Android/Sdk/ndk-bundle` or
|
`/usr/local/home/Android/Sdk/ndk-bundle` or
|
||||||
`/usr/local/home/Android/Sdk/ndk/20.0.5594570`.
|
`/usr/local/home/Android/Sdk/ndk/20.0.5594570`.
|
||||||
|
|
|
@ -37,7 +37,7 @@ each project.
|
||||||
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
|
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
|
||||||
|
|
||||||
mediapipe_aar(
|
mediapipe_aar(
|
||||||
name = "mp_face_detection_aar",
|
name = "mediapipe_face_detection",
|
||||||
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
|
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
@ -45,26 +45,29 @@ each project.
|
||||||
2. Run the Bazel build command to generate the AAR.
|
2. Run the Bazel build command to generate the AAR.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
bazel build -c opt --strip=ALWAYS \
|
||||||
--fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \
|
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||||
//path/to/the/aar/build/file:aar_name
|
--fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||||
|
//path/to/the/aar/build/file:aar_name.aar
|
||||||
```
|
```
|
||||||
|
|
||||||
For the face detection AAR target we made in the step 1, run:
|
For the face detection AAR target we made in step 1, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
bazel build -c opt --strip=ALWAYS \
|
||||||
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar
|
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||||
|
--fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||||
|
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar
|
||||||
|
|
||||||
# It should print:
|
# It should print:
|
||||||
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
|
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar up-to-date:
|
||||||
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||||
```
|
```
|
||||||
|
|
||||||
3. (Optional) Save the AAR to your preferred location.
|
3. (Optional) Save the AAR to your preferred location.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||||
/absolute/path/to/your/preferred/location
|
/absolute/path/to/your/preferred/location
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -75,7 +78,7 @@ each project.
|
||||||
2. Copy the AAR into app/libs.
|
2. Copy the AAR into app/libs.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||||
/path/to/your/app/libs/
|
/path/to/your/app/libs/
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -92,29 +95,14 @@ each project.
|
||||||
[the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite).
|
[the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph
|
bazel build -c opt mediapipe/graphs/face_detection:face_detection_mobile_gpu_binary_graph
|
||||||
cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/
|
cp bazel-bin/mediapipe/graphs/face_detection/face_detection_mobile_gpu.binarypb /path/to/your/app/src/main/assets/
|
||||||
cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/
|
cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/
|
||||||
```
|
```
|
||||||
|
|
||||||
![Screenshot](../images/mobile/assets_location.png)
|
![Screenshot](../images/mobile/assets_location.png)
|
||||||
|
|
||||||
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into
|
4. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
|
||||||
app/src/main/jniLibs.
|
|
||||||
|
|
||||||
MediaPipe depends on OpenCV, you will need to copy the precompiled OpenCV so
|
|
||||||
files into app/src/main/jniLibs. You can download the official OpenCV
|
|
||||||
Android SDK from
|
|
||||||
[here](https://github.com/opencv/opencv/releases/download/3.4.3/opencv-3.4.3-android-sdk.zip)
|
|
||||||
and run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
|
|
||||||
```
|
|
||||||
|
|
||||||
![Screenshot](../images/mobile/android_studio_opencv_location.png)
|
|
||||||
|
|
||||||
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
|
|
||||||
|
|
||||||
```
|
```
|
||||||
dependencies {
|
dependencies {
|
||||||
|
@ -136,10 +124,14 @@ each project.
|
||||||
implementation "androidx.camera:camera-core:$camerax_version"
|
implementation "androidx.camera:camera-core:$camerax_version"
|
||||||
implementation "androidx.camera:camera-camera2:$camerax_version"
|
implementation "androidx.camera:camera-camera2:$camerax_version"
|
||||||
implementation "androidx.camera:camera-lifecycle:$camerax_version"
|
implementation "androidx.camera:camera-lifecycle:$camerax_version"
|
||||||
|
// AutoValue
|
||||||
|
def auto_value_version = "1.6.4"
|
||||||
|
implementation "com.google.auto.value:auto-value-annotations:$auto_value_version"
|
||||||
|
annotationProcessor "com.google.auto.value:auto-value:$auto_value_version"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
6. Follow our Android app examples to use MediaPipe in Android Studio for your
|
5. Follow our Android app examples to use MediaPipe in Android Studio for your
|
||||||
use case. If you are looking for an example, a face detection example can be
|
use case. If you are looking for an example, a face detection example can be
|
||||||
found
|
found
|
||||||
[here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and
|
[here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and
|
||||||
|
|
|
@ -471,7 +471,7 @@ next section.
|
||||||
4. Install Visual C++ Build Tools 2019 and WinSDK
|
4. Install Visual C++ Build Tools 2019 and WinSDK
|
||||||
|
|
||||||
Go to
|
Go to
|
||||||
[the VisualStudio website](ttps://visualstudio.microsoft.com/visual-cpp-build-tools),
|
[the VisualStudio website](https://visualstudio.microsoft.com/visual-cpp-build-tools),
|
||||||
download build tools, and install Microsoft Visual C++ 2019 Redistributable
|
download build tools, and install Microsoft Visual C++ 2019 Redistributable
|
||||||
and Microsoft Build Tools 2019.
|
and Microsoft Build Tools 2019.
|
||||||
|
|
||||||
|
@ -738,7 +738,7 @@ common build issues.
|
||||||
root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh
|
root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh
|
||||||
|
|
||||||
# Should print:
|
# Should print:
|
||||||
# Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r18b
|
# Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r19c
|
||||||
# Set android_ndk_repository and android_sdk_repository in WORKSPACE
|
# Set android_ndk_repository and android_sdk_repository in WORKSPACE
|
||||||
# Done
|
# Done
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ You can, for instance, activate a Python virtual environment:
|
||||||
$ python3 -m venv mp_env && source mp_env/bin/activate
|
$ python3 -m venv mp_env && source mp_env/bin/activate
|
||||||
```
|
```
|
||||||
|
|
||||||
Install MediaPipe Python package and start Python intepreter:
|
Install MediaPipe Python package and start Python interpreter:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
(mp_env)$ pip install mediapipe
|
(mp_env)$ pip install mediapipe
|
||||||
|
|
|
@ -97,6 +97,49 @@ linux_opencv/macos_opencv/windows_opencv.BUILD files for your local opencv
|
||||||
libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666)
|
libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666)
|
||||||
may also help.
|
may also help.
|
||||||
|
|
||||||
|
## Python pip install failure
|
||||||
|
|
||||||
|
The error message:
|
||||||
|
|
||||||
|
```
|
||||||
|
ERROR: Could not find a version that satisfies the requirement mediapipe
|
||||||
|
ERROR: No matching distribution found for mediapipe
|
||||||
|
```
|
||||||
|
|
||||||
|
after running `pip install mediapipe` usually indicates that there is no qualified MediaPipe Python for your system.
|
||||||
|
Please note that MediaPipe Python PyPI officially supports the **64-bit**
|
||||||
|
version of Python 3.7 and above on the following OS:
|
||||||
|
|
||||||
|
- x86_64 Linux
|
||||||
|
- x86_64 macOS 10.15+
|
||||||
|
- amd64 Windows
|
||||||
|
|
||||||
|
If the OS is currently supported and you still see this error, please make sure
|
||||||
|
that both the Python and pip binary are for Python 3.7 and above. Otherwise,
|
||||||
|
please consider building the MediaPipe Python package locally by following the
|
||||||
|
instructions [here](python.md#building-mediapipe-python-package).
|
||||||
|
|
||||||
|
## Python DLL load failure on Windows
|
||||||
|
|
||||||
|
The error message:
|
||||||
|
|
||||||
|
```
|
||||||
|
ImportError: DLL load failed: The specified module could not be found
|
||||||
|
```
|
||||||
|
|
||||||
|
usually indicates that the local Windows system is missing Visual C++
|
||||||
|
redistributable packages and/or Visual C++ runtime DLLs. This can be solved by
|
||||||
|
either installing the official
|
||||||
|
[vc_redist.x64.exe](https://support.microsoft.com/en-us/topic/the-latest-supported-visual-c-downloads-2647da03-1eea-4433-9aff-95f26a218cc0)
|
||||||
|
or installing the "msvc-runtime" Python package by running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python -m pip install msvc-runtime
|
||||||
|
```
|
||||||
|
|
||||||
|
Please note that the "msvc-runtime" Python package is not released or maintained
|
||||||
|
by Microsoft.
|
||||||
|
|
||||||
## Native method not found
|
## Native method not found
|
||||||
|
|
||||||
The error message:
|
The error message:
|
||||||
|
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 75 KiB |
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 42 KiB |
BIN
docs/images/mobile/pose_tracking_example.gif
Normal file
After Width: | Height: | Size: 2.3 MiB |
Before Width: | Height: | Size: 6.9 MiB |
|
@ -77,7 +77,7 @@ Supported configuration options:
|
||||||
```python
|
```python
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
mp_face_detction = mp.solutions.face_detection
|
mp_face_detection = mp.solutions.face_detection
|
||||||
mp_drawing = mp.solutions.drawing_utils
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
|
|
||||||
# For static images:
|
# For static images:
|
||||||
|
|
|
@ -135,12 +135,11 @@ another detection until it loses track, on reducing computation and latency. If
|
||||||
set to `true`, person detection runs every input image, ideal for processing a
|
set to `true`, person detection runs every input image, ideal for processing a
|
||||||
batch of static, possibly unrelated, images. Default to `false`.
|
batch of static, possibly unrelated, images. Default to `false`.
|
||||||
|
|
||||||
#### upper_body_only
|
#### model_complexity
|
||||||
|
|
||||||
If set to `true`, the solution outputs only the 25 upper-body pose landmarks
|
Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
|
||||||
(535 in total) instead of the full set of 33 pose landmarks (543 in total). Note
|
well as inference latency generally go up with the model complexity. Default to
|
||||||
that upper-body-only prediction may be more accurate for use cases where the
|
`1`.
|
||||||
lower-body parts are mostly out of view. Default to `false`.
|
|
||||||
|
|
||||||
#### smooth_landmarks
|
#### smooth_landmarks
|
||||||
|
|
||||||
|
@ -207,7 +206,7 @@ install MediaPipe Python package, then learn more in the companion
|
||||||
Supported configuration options:
|
Supported configuration options:
|
||||||
|
|
||||||
* [static_image_mode](#static_image_mode)
|
* [static_image_mode](#static_image_mode)
|
||||||
* [upper_body_only](#upper_body_only)
|
* [model_complexity](#model_complexity)
|
||||||
* [smooth_landmarks](#smooth_landmarks)
|
* [smooth_landmarks](#smooth_landmarks)
|
||||||
* [min_detection_confidence](#min_detection_confidence)
|
* [min_detection_confidence](#min_detection_confidence)
|
||||||
* [min_tracking_confidence](#min_tracking_confidence)
|
* [min_tracking_confidence](#min_tracking_confidence)
|
||||||
|
@ -219,7 +218,9 @@ mp_drawing = mp.solutions.drawing_utils
|
||||||
mp_holistic = mp.solutions.holistic
|
mp_holistic = mp.solutions.holistic
|
||||||
|
|
||||||
# For static images:
|
# For static images:
|
||||||
with mp_holistic.Holistic(static_image_mode=True) as holistic:
|
with mp_holistic.Holistic(
|
||||||
|
static_image_mode=True,
|
||||||
|
model_complexity=2) as holistic:
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
image = cv2.imread(file)
|
image = cv2.imread(file)
|
||||||
image_height, image_width, _ = image.shape
|
image_height, image_width, _ = image.shape
|
||||||
|
@ -240,8 +241,6 @@ with mp_holistic.Holistic(static_image_mode=True) as holistic:
|
||||||
annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
||||||
mp_drawing.draw_landmarks(
|
mp_drawing.draw_landmarks(
|
||||||
annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
||||||
# Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when
|
|
||||||
# upper_body_only is set to True.
|
|
||||||
mp_drawing.draw_landmarks(
|
mp_drawing.draw_landmarks(
|
||||||
annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
|
annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
|
||||||
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
||||||
|
@ -291,7 +290,7 @@ and the following usage example.
|
||||||
|
|
||||||
Supported configuration options:
|
Supported configuration options:
|
||||||
|
|
||||||
* [upperBodyOnly](#upper_body_only)
|
* [modelComplexity](#model_complexity)
|
||||||
* [smoothLandmarks](#smooth_landmarks)
|
* [smoothLandmarks](#smooth_landmarks)
|
||||||
* [minDetectionConfidence](#min_detection_confidence)
|
* [minDetectionConfidence](#min_detection_confidence)
|
||||||
* [minTrackingConfidence](#min_tracking_confidence)
|
* [minTrackingConfidence](#min_tracking_confidence)
|
||||||
|
@ -348,7 +347,7 @@ const holistic = new Holistic({locateFile: (file) => {
|
||||||
return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`;
|
return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`;
|
||||||
}});
|
}});
|
||||||
holistic.setOptions({
|
holistic.setOptions({
|
||||||
upperBodyOnly: false,
|
modelComplexity: 1,
|
||||||
smoothLandmarks: true,
|
smoothLandmarks: true,
|
||||||
minDetectionConfidence: 0.5,
|
minDetectionConfidence: 0.5,
|
||||||
minTrackingConfidence: 0.5
|
minTrackingConfidence: 0.5
|
||||||
|
|
|
@ -15,10 +15,10 @@ nav_order: 30
|
||||||
### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection)
|
### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection)
|
||||||
|
|
||||||
* Face detection model for front-facing/selfie camera:
|
* Face detection model for front-facing/selfie camera:
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
|
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite),
|
||||||
[TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite)
|
[TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite)
|
||||||
* Face detection model for back-facing camera:
|
* Face detection model for back-facing camera:
|
||||||
[TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite)
|
[TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_back.tflite)
|
||||||
* [Model card](https://mediapipe.page.link/blazeface-mc)
|
* [Model card](https://mediapipe.page.link/blazeface-mc)
|
||||||
|
|
||||||
### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh)
|
### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh)
|
||||||
|
@ -49,10 +49,10 @@ nav_order: 30
|
||||||
|
|
||||||
* Pose detection model:
|
* Pose detection model:
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite)
|
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite)
|
||||||
* Full-body pose landmark model:
|
* Pose landmark model:
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite)
|
[TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite),
|
||||||
* Upper-body pose landmark model:
|
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full.tflite),
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite)
|
[TFLite model (heavy)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite)
|
||||||
* [Model card](https://mediapipe.page.link/blazepose-mc)
|
* [Model card](https://mediapipe.page.link/blazepose-mc)
|
||||||
|
|
||||||
### [Holistic](https://google.github.io/mediapipe/solutions/holistic)
|
### [Holistic](https://google.github.io/mediapipe/solutions/holistic)
|
||||||
|
|
|
@ -30,8 +30,7 @@ overlay of digital content and information on top of the physical world in
|
||||||
augmented reality.
|
augmented reality.
|
||||||
|
|
||||||
MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring
|
MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring
|
||||||
33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video
|
33 3D landmarks on the whole body from RGB video frames utilizing our
|
||||||
frames utilizing our
|
|
||||||
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
|
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
|
||||||
research that also powers the
|
research that also powers the
|
||||||
[ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection).
|
[ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection).
|
||||||
|
@ -40,9 +39,9 @@ environments for inference, whereas our method achieves real-time performance on
|
||||||
most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in
|
most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in
|
||||||
[python](#python-solution-api) and even on the [web](#javascript-solution-api).
|
[python](#python-solution-api) and even on the [web](#javascript-solution-api).
|
||||||
|
|
||||||
![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) |
|
![pose_tracking_example.gif](../images/mobile/pose_tracking_example.gif) |
|
||||||
:--------------------------------------------------------------------------------------------: |
|
:----------------------------------------------------------------------: |
|
||||||
*Fig 1. Example of MediaPipe Pose for upper-body pose tracking.* |
|
*Fig 1. Example of MediaPipe Pose for pose tracking.* |
|
||||||
|
|
||||||
## ML Pipeline
|
## ML Pipeline
|
||||||
|
|
||||||
|
@ -77,6 +76,23 @@ Note: To visualize a graph, copy the graph and paste it into
|
||||||
to visualize its associated subgraphs, please see
|
to visualize its associated subgraphs, please see
|
||||||
[visualizer documentation](../tools/visualizer.md).
|
[visualizer documentation](../tools/visualizer.md).
|
||||||
|
|
||||||
|
## Pose Estimation Quality
|
||||||
|
|
||||||
|
To evaluate the quality of our [models](./models.md#pose) against other
|
||||||
|
well-performing publicly available solutions, we use a validation dataset,
|
||||||
|
consisting of 1k images with diverse Yoga, HIIT, and Dance postures. Each image
|
||||||
|
contains only a single person located 2-4 meters from the camera. To be
|
||||||
|
consistent with other solutions, we perform evaluation only for 17 keypoints
|
||||||
|
from [COCO topology](https://cocodataset.org/#keypoints-2020).
|
||||||
|
|
||||||
|
Method | [mAP](https://cocodataset.org/#keypoints-eval) | [PCK@0.2](https://github.com/cbsudux/Human-Pose-Estimation-101) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), MacBook Pro (15-inch, 2017)
|
||||||
|
----------------------------------------------------------------------------------------------------- | ---------------------------------------------: | --------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------: | ---------------------------------------------------------------------------:
|
||||||
|
BlazePose.Lite | 49.1 | 91.7 | 49 | 40
|
||||||
|
BlazePose.Full | 64.5 | 95.8 | 40 | 37
|
||||||
|
BlazePose.Heavy | 70.9 | 97.0 | 19 | 26
|
||||||
|
[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 57.6 | 93.1 | N/A | N/A
|
||||||
|
[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 37.0 | 85.3 | N/A | N/A
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
### Person/pose Detection Model (BlazePose Detector)
|
### Person/pose Detection Model (BlazePose Detector)
|
||||||
|
@ -97,11 +113,8 @@ hip midpoints.
|
||||||
|
|
||||||
### Pose Landmark Model (BlazePose GHUM 3D)
|
### Pose Landmark Model (BlazePose GHUM 3D)
|
||||||
|
|
||||||
The landmark model in MediaPipe Pose comes in two versions: a full-body model
|
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
|
||||||
that predicts the location of 33 pose landmarks (see figure below), and an
|
(see figure below).
|
||||||
upper-body version that only predicts the first 25. The latter may be more
|
|
||||||
accurate than the former in scenarios where the lower-body parts are mostly out
|
|
||||||
of view.
|
|
||||||
|
|
||||||
Please find more detail in the
|
Please find more detail in the
|
||||||
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
|
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
|
||||||
|
@ -129,12 +142,11 @@ until it loses track, on reducing computation and latency. If set to `true`,
|
||||||
person detection runs every input image, ideal for processing a batch of static,
|
person detection runs every input image, ideal for processing a batch of static,
|
||||||
possibly unrelated, images. Default to `false`.
|
possibly unrelated, images. Default to `false`.
|
||||||
|
|
||||||
#### upper_body_only
|
#### model_complexity
|
||||||
|
|
||||||
If set to `true`, the solution outputs only the 25 upper-body pose landmarks.
|
Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
|
||||||
Otherwise, it outputs the full set of 33 pose landmarks. Note that
|
well as inference latency generally go up with the model complexity. Default to
|
||||||
upper-body-only prediction may be more accurate for use cases where the
|
`1`.
|
||||||
lower-body parts are mostly out of view. Default to `false`.
|
|
||||||
|
|
||||||
#### smooth_landmarks
|
#### smooth_landmarks
|
||||||
|
|
||||||
|
@ -170,9 +182,6 @@ A list of pose landmarks. Each lanmark consists of the following:
|
||||||
being the origin, and the smaller the value the closer the landmark is to
|
being the origin, and the smaller the value the closer the landmark is to
|
||||||
the camera. The magnitude of `z` uses roughly the same scale as `x`.
|
the camera. The magnitude of `z` uses roughly the same scale as `x`.
|
||||||
|
|
||||||
Note: `z` is predicted only in full-body mode, and should be discarded when
|
|
||||||
[upper_body_only](#upper_body_only) is `true`.
|
|
||||||
|
|
||||||
* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the
|
* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the
|
||||||
landmark being visible (present and not occluded) in the image.
|
landmark being visible (present and not occluded) in the image.
|
||||||
|
|
||||||
|
@ -185,7 +194,7 @@ install MediaPipe Python package, then learn more in the companion
|
||||||
Supported configuration options:
|
Supported configuration options:
|
||||||
|
|
||||||
* [static_image_mode](#static_image_mode)
|
* [static_image_mode](#static_image_mode)
|
||||||
* [upper_body_only](#upper_body_only)
|
* [model_complexity](#model_complexity)
|
||||||
* [smooth_landmarks](#smooth_landmarks)
|
* [smooth_landmarks](#smooth_landmarks)
|
||||||
* [min_detection_confidence](#min_detection_confidence)
|
* [min_detection_confidence](#min_detection_confidence)
|
||||||
* [min_tracking_confidence](#min_tracking_confidence)
|
* [min_tracking_confidence](#min_tracking_confidence)
|
||||||
|
@ -198,7 +207,9 @@ mp_pose = mp.solutions.pose
|
||||||
|
|
||||||
# For static images:
|
# For static images:
|
||||||
with mp_pose.Pose(
|
with mp_pose.Pose(
|
||||||
static_image_mode=True, min_detection_confidence=0.5) as pose:
|
static_image_mode=True,
|
||||||
|
model_complexity=2,
|
||||||
|
min_detection_confidence=0.5) as pose:
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
image = cv2.imread(file)
|
image = cv2.imread(file)
|
||||||
image_height, image_width, _ = image.shape
|
image_height, image_width, _ = image.shape
|
||||||
|
@ -214,8 +225,6 @@ with mp_pose.Pose(
|
||||||
)
|
)
|
||||||
# Draw pose landmarks on the image.
|
# Draw pose landmarks on the image.
|
||||||
annotated_image = image.copy()
|
annotated_image = image.copy()
|
||||||
# Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when
|
|
||||||
# upper_body_only is set to True.
|
|
||||||
mp_drawing.draw_landmarks(
|
mp_drawing.draw_landmarks(
|
||||||
annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
|
annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
|
||||||
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
||||||
|
@ -259,7 +268,7 @@ and the following usage example.
|
||||||
|
|
||||||
Supported configuration options:
|
Supported configuration options:
|
||||||
|
|
||||||
* [upperBodyOnly](#upper_body_only)
|
* [modelComplexity](#model_complexity)
|
||||||
* [smoothLandmarks](#smooth_landmarks)
|
* [smoothLandmarks](#smooth_landmarks)
|
||||||
* [minDetectionConfidence](#min_detection_confidence)
|
* [minDetectionConfidence](#min_detection_confidence)
|
||||||
* [minTrackingConfidence](#min_tracking_confidence)
|
* [minTrackingConfidence](#min_tracking_confidence)
|
||||||
|
@ -306,7 +315,7 @@ const pose = new Pose({locateFile: (file) => {
|
||||||
return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`;
|
return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`;
|
||||||
}});
|
}});
|
||||||
pose.setOptions({
|
pose.setOptions({
|
||||||
upperBodyOnly: false,
|
modelComplexity: 1,
|
||||||
smoothLandmarks: true,
|
smoothLandmarks: true,
|
||||||
minDetectionConfidence: 0.5,
|
minDetectionConfidence: 0.5,
|
||||||
minTrackingConfidence: 0.5
|
minTrackingConfidence: 0.5
|
||||||
|
@ -347,16 +356,6 @@ to visualize its associated subgraphs, please see
|
||||||
* iOS target:
|
* iOS target:
|
||||||
[`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD)
|
[`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD)
|
||||||
|
|
||||||
#### Upper-body Only
|
|
||||||
|
|
||||||
* Graph:
|
|
||||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
|
|
||||||
* Android target:
|
|
||||||
[(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1uKc6T7KSuA0Mlq2URi5YookHu0U3yoh_/view?usp=sharing)
|
|
||||||
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu:upperbodyposetrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD)
|
|
||||||
* iOS target:
|
|
||||||
[`mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp`](http:/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD)
|
|
||||||
|
|
||||||
### Desktop
|
### Desktop
|
||||||
|
|
||||||
Please first see general instructions for [desktop](../getting_started/cpp.md)
|
Please first see general instructions for [desktop](../getting_started/cpp.md)
|
||||||
|
@ -375,19 +374,6 @@ on how to build MediaPipe examples.
|
||||||
* Target:
|
* Target:
|
||||||
[`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
|
[`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
|
||||||
|
|
||||||
#### Upper-body Only
|
|
||||||
|
|
||||||
* Running on CPU
|
|
||||||
* Graph:
|
|
||||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt)
|
|
||||||
* Target:
|
|
||||||
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
|
|
||||||
* Running on GPU
|
|
||||||
* Graph:
|
|
||||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
|
|
||||||
* Target:
|
|
||||||
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
|
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
* Google AI Blog:
|
* Google AI Blog:
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
|
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
|
||||||
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
|
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
|
||||||
"mediapipe/examples/ios/posetrackinggpu/BUILD",
|
"mediapipe/examples/ios/posetrackinggpu/BUILD",
|
||||||
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD",
|
|
||||||
"mediapipe/framework/BUILD",
|
"mediapipe/framework/BUILD",
|
||||||
"mediapipe/gpu/BUILD",
|
"mediapipe/gpu/BUILD",
|
||||||
"mediapipe/objc/BUILD",
|
"mediapipe/objc/BUILD",
|
||||||
|
@ -36,7 +35,6 @@
|
||||||
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
|
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
|
||||||
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
|
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
|
||||||
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
|
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
|
||||||
"//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp",
|
|
||||||
"//mediapipe/objc:mediapipe_framework_ios"
|
"//mediapipe/objc:mediapipe_framework_ios"
|
||||||
],
|
],
|
||||||
"optionSet" : {
|
"optionSet" : {
|
||||||
|
@ -105,7 +103,6 @@
|
||||||
"mediapipe/examples/ios/objectdetectioncpu",
|
"mediapipe/examples/ios/objectdetectioncpu",
|
||||||
"mediapipe/examples/ios/objectdetectiongpu",
|
"mediapipe/examples/ios/objectdetectiongpu",
|
||||||
"mediapipe/examples/ios/posetrackinggpu",
|
"mediapipe/examples/ios/posetrackinggpu",
|
||||||
"mediapipe/examples/ios/upperbodyposetrackinggpu",
|
|
||||||
"mediapipe/framework",
|
"mediapipe/framework",
|
||||||
"mediapipe/framework/deps",
|
"mediapipe/framework/deps",
|
||||||
"mediapipe/framework/formats",
|
"mediapipe/framework/formats",
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
"mediapipe/examples/ios/objectdetectiongpu",
|
"mediapipe/examples/ios/objectdetectiongpu",
|
||||||
"mediapipe/examples/ios/objectdetectiontrackinggpu",
|
"mediapipe/examples/ios/objectdetectiontrackinggpu",
|
||||||
"mediapipe/examples/ios/posetrackinggpu",
|
"mediapipe/examples/ios/posetrackinggpu",
|
||||||
"mediapipe/examples/ios/upperbodyposetrackinggpu",
|
|
||||||
"mediapipe/objc"
|
"mediapipe/objc"
|
||||||
],
|
],
|
||||||
"projectName" : "Mediapipe",
|
"projectName" : "Mediapipe",
|
||||||
|
|
|
@ -451,8 +451,8 @@ cc_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "nonzero_calculator",
|
name = "non_zero_calculator",
|
||||||
srcs = ["nonzero_calculator.cc"],
|
srcs = ["non_zero_calculator.cc"],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//visibility:public",
|
"//visibility:public",
|
||||||
],
|
],
|
||||||
|
@ -464,6 +464,21 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "non_zero_calculator_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["non_zero_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":non_zero_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/framework/tool:validate_type",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "mux_calculator_test",
|
name = "mux_calculator_test",
|
||||||
srcs = ["mux_calculator_test.cc"],
|
srcs = ["mux_calculator_test.cc"],
|
||||||
|
@ -665,6 +680,18 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "default_side_packet_calculator",
|
||||||
|
srcs = ["default_side_packet_calculator.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "side_packet_to_stream_calculator",
|
name = "side_packet_to_stream_calculator",
|
||||||
srcs = ["side_packet_to_stream_calculator.cc"],
|
srcs = ["side_packet_to_stream_calculator.cc"],
|
||||||
|
|
103
mediapipe/calculators/core/default_side_packet_calculator.cc
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
// Copyright 2019 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE";
|
||||||
|
constexpr char kDefaultValueTag[] = "DEFAULT_VALUE";
|
||||||
|
constexpr char kValueTag[] = "VALUE";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Outputs side packet default value if optional value is not provided.
|
||||||
|
//
|
||||||
|
// This calculator utilizes the fact that MediaPipe automatically removes
|
||||||
|
// optional side packets of the calculator configuration (i.e. OPTIONAL_VALUE).
|
||||||
|
// And if it happens - returns default value, otherwise - returns optional
|
||||||
|
// value.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE)
|
||||||
|
// Optional side packet value that is outputted by the calculator as is if
|
||||||
|
// provided.
|
||||||
|
//
|
||||||
|
// DEFAULT_VALUE - AnyType
|
||||||
|
// Default side pack value that is outputted by the calculator if
|
||||||
|
// OPTIONAL_VALUE is not provided.
|
||||||
|
//
|
||||||
|
// Output:
|
||||||
|
// VALUE - AnyType (but same type as DEFAULT_VALUE)
|
||||||
|
// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise).
|
||||||
|
//
|
||||||
|
// Usage example:
|
||||||
|
// node {
|
||||||
|
// calculator: "DefaultSidePacketCalculator"
|
||||||
|
// input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
|
||||||
|
// input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
|
||||||
|
// output_side_packet: "VALUE:segmentation_mask_enabled"
|
||||||
|
// }
|
||||||
|
class DefaultSidePacketCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc);
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
};
|
||||||
|
REGISTER_CALCULATOR(DefaultSidePacketCalculator);
|
||||||
|
|
||||||
|
absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
|
||||||
|
RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
|
||||||
|
<< "Default value must be provided";
|
||||||
|
cc->InputSidePackets().Tag(kDefaultValueTag).SetAny();
|
||||||
|
|
||||||
|
// Optional input side packet can be unspecified. In this case MediaPipe will
|
||||||
|
// remove it from the calculator config.
|
||||||
|
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||||
|
cc->InputSidePackets()
|
||||||
|
.Tag(kOptionalValueTag)
|
||||||
|
.SetSameAs(&cc->InputSidePackets().Tag(kDefaultValueTag));
|
||||||
|
}
|
||||||
|
|
||||||
|
RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
|
||||||
|
cc->OutputSidePackets().Tag(kValueTag).SetSameAs(
|
||||||
|
&cc->InputSidePackets().Tag(kDefaultValueTag));
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
|
||||||
|
// If optional value is provided it is returned as the calculator output.
|
||||||
|
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||||
|
auto& packet = cc->InputSidePackets().Tag(kOptionalValueTag);
|
||||||
|
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no optional value
|
||||||
|
auto& packet = cc->InputSidePackets().Tag(kDefaultValueTag);
|
||||||
|
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -23,14 +23,26 @@ namespace api2 {
|
||||||
class NonZeroCalculator : public Node {
|
class NonZeroCalculator : public Node {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<int>::SideFallback kIn{"INPUT"};
|
static constexpr Input<int>::SideFallback kIn{"INPUT"};
|
||||||
static constexpr Output<int> kOut{"OUTPUT"};
|
static constexpr Output<int>::Optional kOut{"OUTPUT"};
|
||||||
|
static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};
|
||||||
|
|
||||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);
|
||||||
|
|
||||||
|
absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
|
||||||
|
<< "At least one output stream is expected.";
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
if (!kIn(cc).IsEmpty()) {
|
if (!kIn(cc).IsEmpty()) {
|
||||||
auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0);
|
bool isNonZero = *kIn(cc) != 0;
|
||||||
kOut(cc).Send(std::move(output));
|
if (kOut(cc).IsConnected()) {
|
||||||
|
kOut(cc).Send(std::make_unique<int>(isNonZero ? 1 : 0));
|
||||||
|
}
|
||||||
|
if (kBooleanOut(cc).IsConnected()) {
|
||||||
|
kBooleanOut(cc).Send(std::make_unique<bool>(isNonZero));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
93
mediapipe/calculators/core/non_zero_calculator_test.cc
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/framework/tool/validate_type.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
class NonZeroCalculatorTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
NonZeroCalculatorTest()
|
||||||
|
: runner_(
|
||||||
|
R"pb(
|
||||||
|
calculator: "NonZeroCalculator"
|
||||||
|
input_stream: "INPUT:input"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
output_stream: "OUTPUT_BOOL:output_bool"
|
||||||
|
)pb") {}
|
||||||
|
|
||||||
|
void SetInput(const std::vector<int>& inputs) {
|
||||||
|
int timestamp = 0;
|
||||||
|
for (const auto input : inputs) {
|
||||||
|
runner_.MutableInputs()
|
||||||
|
->Get("INPUT", 0)
|
||||||
|
.packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetOutput() {
|
||||||
|
std::vector<int> result;
|
||||||
|
for (const auto output : runner_.Outputs().Get("OUTPUT", 0).packets) {
|
||||||
|
result.push_back(output.Get<int>());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<bool> GetOutputBool() {
|
||||||
|
std::vector<bool> result;
|
||||||
|
for (const auto output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
|
||||||
|
result.push_back(output.Get<bool>());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorRunner runner_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
|
||||||
|
SetInput({0});
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_.Run());
|
||||||
|
|
||||||
|
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
|
||||||
|
EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
|
||||||
|
SetInput({1, 2, 3, -4, 5});
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_.Run());
|
||||||
|
|
||||||
|
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
|
||||||
|
EXPECT_THAT(GetOutputBool(),
|
||||||
|
::testing::ElementsAre(true, true, true, true, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
|
||||||
|
SetInput({1, 0, 3, 0, 5});
|
||||||
|
MP_ASSERT_OK(runner_.Run());
|
||||||
|
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
|
||||||
|
EXPECT_THAT(GetOutputBool(),
|
||||||
|
::testing::ElementsAre(true, false, true, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -285,7 +285,7 @@ absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Run cropping shader on GPU.
|
// Run cropping shader on GPU.
|
||||||
{
|
{
|
||||||
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(dst_tex);
|
||||||
|
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(src_tex.target(), src_tex.name());
|
glBindTexture(src_tex.target(), src_tex.name());
|
||||||
|
|
|
@ -546,7 +546,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height,
|
auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height,
|
||||||
input.format());
|
input.format());
|
||||||
|
|
||||||
gpu_helper_.BindFramebuffer(dst); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(dst);
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(src1.target(), src1.name());
|
glBindTexture(src1.target(), src1.name());
|
||||||
|
|
||||||
|
|
|
@ -209,6 +209,9 @@ absl::Status RecolorCalculator::Close(CalculatorContext* cc) {
|
||||||
|
|
||||||
absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
|
absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) {
|
if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) {
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kImageFrameTag)
|
||||||
|
.AddPacket(cc->Inputs().Tag(kImageFrameTag).Value());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
// Get inputs and setup output.
|
// Get inputs and setup output.
|
||||||
|
@ -270,6 +273,9 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
|
|
||||||
absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) {
|
if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) {
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kGpuBufferTag)
|
||||||
|
.AddPacket(cc->Inputs().Tag(kGpuBufferTag).Value());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
@ -287,7 +293,7 @@ absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Run recolor shader on GPU.
|
// Run recolor shader on GPU.
|
||||||
{
|
{
|
||||||
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(dst_tex);
|
||||||
|
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(img_tex.target(), img_tex.name());
|
glBindTexture(img_tex.target(), img_tex.name());
|
||||||
|
|
|
@ -323,7 +323,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
const auto& alpha_mask =
|
const auto& alpha_mask =
|
||||||
cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>();
|
cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>();
|
||||||
auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask);
|
auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask);
|
||||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(output_texture);
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||||
glActiveTexture(GL_TEXTURE2);
|
glActiveTexture(GL_TEXTURE2);
|
||||||
|
@ -335,7 +335,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
|
||||||
glBindTexture(GL_TEXTURE_2D, 0);
|
glBindTexture(GL_TEXTURE_2D, 0);
|
||||||
alpha_texture.Release();
|
alpha_texture.Release();
|
||||||
} else {
|
} else {
|
||||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(output_texture);
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||||
GlRender(cc); // use value from options
|
GlRender(cc); // use value from options
|
||||||
|
|
|
@ -490,6 +490,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:port",
|
"//mediapipe/framework:port",
|
||||||
|
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||||
] + select({
|
] + select({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
"//conditions:default": [":image_to_tensor_calculator_gpu_deps"],
|
"//conditions:default": [":image_to_tensor_calculator_gpu_deps"],
|
||||||
|
@ -526,6 +527,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/gpu:gpu_origin_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
|
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||||
|
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
|
@ -236,7 +237,7 @@ class ImageToTensorCalculator : public Node {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool DoesInputStartAtBottom() {
|
bool DoesGpuInputStartAtBottom() {
|
||||||
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,11 +291,11 @@ class ImageToTensorCalculator : public Node {
|
||||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
ASSIGN_OR_RETURN(gpu_converter_,
|
ASSIGN_OR_RETURN(gpu_converter_,
|
||||||
CreateImageToGlBufferTensorConverter(
|
CreateImageToGlBufferTensorConverter(
|
||||||
cc, DoesInputStartAtBottom(), GetBorderMode()));
|
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||||
#else
|
#else
|
||||||
ASSIGN_OR_RETURN(gpu_converter_,
|
ASSIGN_OR_RETURN(gpu_converter_,
|
||||||
CreateImageToGlTextureTensorConverter(
|
CreateImageToGlTextureTensorConverter(
|
||||||
cc, DoesInputStartAtBottom(), GetBorderMode()));
|
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,20 +17,7 @@ syntax = "proto2";
|
||||||
package mediapipe;
|
package mediapipe;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/gpu/gpu_origin.proto";
|
||||||
message GpuOrigin {
|
|
||||||
enum Mode {
|
|
||||||
DEFAULT = 0;
|
|
||||||
|
|
||||||
// OpenGL: bottom-left origin
|
|
||||||
// Metal : top-left origin
|
|
||||||
CONVENTIONAL = 1;
|
|
||||||
|
|
||||||
// OpenGL: top-left origin
|
|
||||||
// Metal : top-left origin
|
|
||||||
TOP_LEFT = 2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
message ImageToTensorCalculatorOptions {
|
message ImageToTensorCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
|
|
|
@ -317,7 +317,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
||||||
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||||
options.compile_options.precision_loss_allowed = 1;
|
options.compile_options.precision_loss_allowed =
|
||||||
|
allow_precision_loss_ ? 1 : 0;
|
||||||
options.compile_options.preferred_gl_object_type =
|
options.compile_options.preferred_gl_object_type =
|
||||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||||
options.compile_options.dynamic_batch_enabled = 0;
|
options.compile_options.dynamic_batch_enabled = 0;
|
||||||
|
|
|
@ -97,6 +97,7 @@ class InferenceCalculatorMetalImpl
|
||||||
Packet<TfLiteModelPtr> model_packet_;
|
Packet<TfLiteModelPtr> model_packet_;
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||||
TfLiteDelegatePtr delegate_;
|
TfLiteDelegatePtr delegate_;
|
||||||
|
bool allow_precision_loss_ = false;
|
||||||
|
|
||||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||||
|
@ -122,6 +123,9 @@ absl::Status InferenceCalculatorMetalImpl::UpdateContract(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
|
||||||
|
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||||
|
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||||
|
|
||||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||||
|
@ -222,7 +226,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TFLGpuDelegateOptions options;
|
TFLGpuDelegateOptions options;
|
||||||
options.allow_precision_loss = true;
|
options.allow_precision_loss = allow_precision_loss_;
|
||||||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
|
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
|
||||||
delegate_ =
|
delegate_ =
|
||||||
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
|
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
|
||||||
|
@ -239,7 +243,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
tensor->dims->data + tensor->dims->size};
|
tensor->dims->data + tensor->dims->size};
|
||||||
dims.back() = RoundUp(dims.back(), 4);
|
dims.back() = RoundUp(dims.back(), 4);
|
||||||
gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>(
|
gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>(
|
||||||
Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
|
allow_precision_loss_ ? Tensor::ElementType::kFloat16
|
||||||
|
: Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{dims}));
|
||||||
auto buffer_view =
|
auto buffer_view =
|
||||||
gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
|
gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
|
||||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||||
|
@ -261,7 +267,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
output_shapes_[i] = {dims};
|
output_shapes_[i] = {dims};
|
||||||
dims.back() = RoundUp(dims.back(), 4);
|
dims.back() = RoundUp(dims.back(), 4);
|
||||||
gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>(
|
gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>(
|
||||||
Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
|
allow_precision_loss_ ? Tensor::ElementType::kFloat16
|
||||||
|
: Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{dims}));
|
||||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||||
delegate_.get(), output_indices[i],
|
delegate_.get(), output_indices[i],
|
||||||
gpu_buffers_out_[i]
|
gpu_buffers_out_[i]
|
||||||
|
@ -271,17 +279,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create converter for GPU input.
|
// Create converter for GPU input.
|
||||||
converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
converter_to_BPHWC4_ =
|
||||||
isFloat16:true
|
[[TFLBufferConvert alloc] initWithDevice:device
|
||||||
convertToPBHWC4:true];
|
isFloat16:allow_precision_loss_
|
||||||
|
convertToPBHWC4:true];
|
||||||
if (converter_to_BPHWC4_ == nil) {
|
if (converter_to_BPHWC4_ == nil) {
|
||||||
return mediapipe::InternalError(
|
return mediapipe::InternalError(
|
||||||
"Error initializating input buffer converter");
|
"Error initializating input buffer converter");
|
||||||
}
|
}
|
||||||
// Create converter for GPU output.
|
// Create converter for GPU output.
|
||||||
converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
converter_from_BPHWC4_ =
|
||||||
isFloat16:true
|
[[TFLBufferConvert alloc] initWithDevice:device
|
||||||
convertToPBHWC4:false];
|
isFloat16:allow_precision_loss_
|
||||||
|
convertToPBHWC4:false];
|
||||||
if (converter_from_BPHWC4_ == nil) {
|
if (converter_from_BPHWC4_ == nil) {
|
||||||
return absl::InternalError("Error initializating output buffer converter");
|
return absl::InternalError("Error initializating output buffer converter");
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,8 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||||
ASSIGN_OR_RETURN(string_path,
|
ASSIGN_OR_RETURN(string_path,
|
||||||
PathToResourceAsFile(options_.label_map_path()));
|
PathToResourceAsFile(options_.label_map_path()));
|
||||||
std::string label_map_string;
|
std::string label_map_string;
|
||||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
MP_RETURN_IF_ERROR(
|
||||||
|
mediapipe::GetResourceContents(string_path, &label_map_string));
|
||||||
|
|
||||||
std::istringstream stream(label_map_string);
|
std::istringstream stream(label_map_string);
|
||||||
std::string line;
|
std::string line;
|
||||||
|
@ -98,6 +99,14 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||||
label_map_[i++] = line;
|
label_map_[i++] = line;
|
||||||
}
|
}
|
||||||
label_map_loaded_ = true;
|
label_map_loaded_ = true;
|
||||||
|
} else if (options_.has_label_map()) {
|
||||||
|
for (int i = 0; i < options_.label_map().entries_size(); ++i) {
|
||||||
|
const auto& entry = options_.label_map().entries(i);
|
||||||
|
RET_CHECK(!label_map_.contains(entry.id()))
|
||||||
|
<< "Duplicate id found: " << entry.id();
|
||||||
|
label_map_[entry.id()] = entry.label();
|
||||||
|
}
|
||||||
|
label_map_loaded_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -25,6 +25,14 @@ message TensorsToClassificationCalculatorOptions {
|
||||||
optional TensorsToClassificationCalculatorOptions ext = 335742638;
|
optional TensorsToClassificationCalculatorOptions ext = 335742638;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message LabelMap {
|
||||||
|
message Entry {
|
||||||
|
optional int32 id = 1;
|
||||||
|
optional string label = 2;
|
||||||
|
}
|
||||||
|
repeated Entry entries = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Score threshold for perserving the class.
|
// Score threshold for perserving the class.
|
||||||
optional float min_score_threshold = 1;
|
optional float min_score_threshold = 1;
|
||||||
// Number of highest scoring labels to output. If top_k is not positive then
|
// Number of highest scoring labels to output. If top_k is not positive then
|
||||||
|
@ -32,6 +40,10 @@ message TensorsToClassificationCalculatorOptions {
|
||||||
optional int32 top_k = 2;
|
optional int32 top_k = 2;
|
||||||
// Path to a label map file for getting the actual name of class ids.
|
// Path to a label map file for getting the actual name of class ids.
|
||||||
optional string label_map_path = 3;
|
optional string label_map_path = 3;
|
||||||
|
// Label map. (Can be used instead of label_map_path.)
|
||||||
|
// NOTE: "label_map_path", if specified, takes precedence over "label_map".
|
||||||
|
optional LabelMap label_map = 5;
|
||||||
|
|
||||||
// Whether the input is a single float for binary classification.
|
// Whether the input is a single float for binary classification.
|
||||||
// When true, only a single float is expected in the input tensor and the
|
// When true, only a single float is expected in the input tensor and the
|
||||||
// label map, if provided, is expected to have exactly two labels.
|
// label map, if provided, is expected to have exactly two labels.
|
||||||
|
|
|
@ -115,6 +115,41 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMapPath) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMap) {
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToClassificationCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "CLASSIFICATIONS:classifications"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||||
|
label_map {
|
||||||
|
entries { id: 0, label: "ClassA" }
|
||||||
|
entries { id: 1, label: "ClassB" }
|
||||||
|
entries { id: 2, label: "ClassC" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0, 0.5, 1});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||||
|
|
||||||
|
EXPECT_EQ(1, output_packets_.size());
|
||||||
|
|
||||||
|
const auto& classification_list =
|
||||||
|
output_packets_[0].Get<ClassificationList>();
|
||||||
|
EXPECT_EQ(3, classification_list.classification_size());
|
||||||
|
|
||||||
|
// Verify that the label field is set.
|
||||||
|
for (int i = 0; i < classification_list.classification_size(); ++i) {
|
||||||
|
EXPECT_EQ(i, classification_list.classification(i).index());
|
||||||
|
EXPECT_EQ(i * 0.5, classification_list.classification(i).score());
|
||||||
|
ASSERT_TRUE(classification_list.classification(i).has_label());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
CorrectOutputWithLabelMinScoreThreshold) {
|
CorrectOutputWithLabelMinScoreThreshold) {
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
|
|
@ -34,15 +34,28 @@ constexpr char kTensor[] = "TENSOR";
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Input:
|
// Input:
|
||||||
// Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The
|
// Tensor of type DT_FLOAT or DT_UINT8, with values between 0-255
|
||||||
// shape can be HxWx{3,1} or simply HxW.
|
// (SRGB or GRAY8). The shape can be HxWx{3,1} or simply HxW.
|
||||||
//
|
//
|
||||||
// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255.
|
// For DT_FLOAT tensors, optionally supports a scale factor that can scale 0-1
|
||||||
|
// value ranges to 0-255.
|
||||||
//
|
//
|
||||||
// Output:
|
// Output:
|
||||||
// ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
|
// ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
|
||||||
//
|
//
|
||||||
// Possible extensions: support other input ranges, maybe 4D tensors.
|
// Possible extensions: support other input ranges, maybe 4D tensors.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "TensorToImageFrameCalculator"
|
||||||
|
// input_stream: "TENSOR:3d_float_tensor"
|
||||||
|
// output_stream: "IMAGE:image_frame"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.TensorToImageFrameCalculatorOptions.ext] {
|
||||||
|
// scale_factor: 1.0 # set to 255.0 for [0,1] -> [0,255] scaling
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
class TensorToImageFrameCalculator : public CalculatorBase {
|
class TensorToImageFrameCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc);
|
static absl::Status GetContract(CalculatorContract* cc);
|
||||||
|
@ -57,8 +70,8 @@ class TensorToImageFrameCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
||||||
|
|
||||||
absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
|
||||||
<< "Only one input stream is supported.";
|
<< "Only one output stream is supported.";
|
||||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
||||||
<< "One input stream must be provided.";
|
<< "One input stream must be provided.";
|
||||||
RET_CHECK(cc->Inputs().HasTag(kTensor))
|
RET_CHECK(cc->Inputs().HasTag(kTensor))
|
||||||
|
@ -91,29 +104,44 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
||||||
RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1.";
|
RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int32 total_size =
|
int32 height = input_tensor.dim_size(0);
|
||||||
input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth;
|
int32 width = input_tensor.dim_size(1);
|
||||||
std::unique_ptr<uint8[]> buffer(new uint8[total_size]);
|
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
||||||
auto data = input_tensor.flat<float>().data();
|
const int32 total_size = height * width * depth;
|
||||||
for (int i = 0; i < total_size; ++i) {
|
|
||||||
float d = scale_factor_ * data[i];
|
::std::unique_ptr<const ImageFrame> output;
|
||||||
if (d < 0) d = 0;
|
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
||||||
if (d > 255) d = 255;
|
// Allocate buffer with alignments.
|
||||||
buffer[i] = d;
|
std::unique_ptr<uint8_t[]> buffer(
|
||||||
|
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
||||||
|
auto data = input_tensor.flat<float>().data();
|
||||||
|
for (int i = 0; i < total_size; ++i) {
|
||||||
|
float d = scale_factor_ * data[i];
|
||||||
|
if (d < 0) d = 0;
|
||||||
|
if (d > 255) d = 255;
|
||||||
|
buffer[i] = d;
|
||||||
|
}
|
||||||
|
output = ::absl::make_unique<ImageFrame>(format, width, height,
|
||||||
|
width * depth, buffer.release());
|
||||||
|
} else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
|
||||||
|
if (scale_factor_ != 1.0) {
|
||||||
|
return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");
|
||||||
|
}
|
||||||
|
// tf::Tensor has internally ref-counted buffer. The following code make the
|
||||||
|
// ImageFrame own the copied Tensor through the deleter, which increases
|
||||||
|
// the refcount of the buffer and allow us to use the shared buffer as the
|
||||||
|
// image. This allows us to create an ImageFrame object without copying
|
||||||
|
// buffer. const ImageFrame prevents the buffer from being modified later.
|
||||||
|
auto copy = new tf::Tensor(input_tensor);
|
||||||
|
output = ::absl::make_unique<const ImageFrame>(
|
||||||
|
format, width, height, width * depth, copy->flat<uint8_t>().data(),
|
||||||
|
[copy](uint8*) { delete copy; });
|
||||||
|
} else {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
absl::StrCat("Expected float or uint8 tensor, received ",
|
||||||
|
DataTypeString(input_tensor.dtype())));
|
||||||
}
|
}
|
||||||
|
|
||||||
::std::unique_ptr<ImageFrame> output;
|
|
||||||
if (depth == 3) {
|
|
||||||
output = ::absl::make_unique<ImageFrame>(
|
|
||||||
ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0),
|
|
||||||
input_tensor.dim_size(1) * 3, buffer.release());
|
|
||||||
} else if (depth == 1) {
|
|
||||||
output = ::absl::make_unique<ImageFrame>(
|
|
||||||
ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0),
|
|
||||||
input_tensor.dim_size(1), buffer.release());
|
|
||||||
} else {
|
|
||||||
return absl::InvalidArgumentError("Unrecognized image depth.");
|
|
||||||
}
|
|
||||||
cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp());
|
cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp());
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -29,6 +29,7 @@ constexpr char kImage[] = "IMAGE";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
template <class TypeParam>
|
||||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpRunner() {
|
void SetUpRunner() {
|
||||||
|
@ -42,14 +43,20 @@ class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||||
std::unique_ptr<CalculatorRunner> runner_;
|
std::unique_ptr<CalculatorRunner> runner_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
using TensorToImageFrameCalculatorTestTypes = ::testing::Types<float, uint8_t>;
|
||||||
SetUpRunner();
|
TYPED_TEST_CASE(TensorToImageFrameCalculatorTest,
|
||||||
|
TensorToImageFrameCalculatorTestTypes);
|
||||||
|
|
||||||
|
TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
||||||
|
// TYPED_TEST requires explicit "this->"
|
||||||
|
this->SetUpRunner();
|
||||||
|
auto& runner = this->runner_;
|
||||||
constexpr int kWidth = 16;
|
constexpr int kWidth = 16;
|
||||||
constexpr int kHeight = 8;
|
constexpr int kHeight = 8;
|
||||||
const tf::TensorShape tensor_shape(
|
const tf::TensorShape tensor_shape{kHeight, kWidth, 3};
|
||||||
std::vector<tf::int64>{kHeight, kWidth, 3});
|
auto tensor = absl::make_unique<tf::Tensor>(
|
||||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||||
auto tensor_vec = tensor->flat<float>().data();
|
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||||
|
|
||||||
// Writing sequence of integers as floats which we want back (as they were
|
// Writing sequence of integers as floats which we want back (as they were
|
||||||
// written).
|
// written).
|
||||||
|
@ -58,15 +65,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||||
Adopt(tensor.release()).At(Timestamp(time)));
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner->Run().ok());
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag(kImage).packets;
|
runner->Outputs().Tag(kImage).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||||
|
EXPECT_EQ(ImageFormat::SRGB, output_image.Format());
|
||||||
EXPECT_EQ(kWidth, output_image.Width());
|
EXPECT_EQ(kWidth, output_image.Width());
|
||||||
EXPECT_EQ(kHeight, output_image.Height());
|
EXPECT_EQ(kHeight, output_image.Height());
|
||||||
|
|
||||||
|
@ -76,14 +84,15 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
||||||
SetUpRunner();
|
this->SetUpRunner();
|
||||||
|
auto& runner = this->runner_;
|
||||||
constexpr int kWidth = 16;
|
constexpr int kWidth = 16;
|
||||||
constexpr int kHeight = 8;
|
constexpr int kHeight = 8;
|
||||||
const tf::TensorShape tensor_shape(
|
const tf::TensorShape tensor_shape{kHeight, kWidth, 1};
|
||||||
std::vector<tf::int64>{kHeight, kWidth, 1});
|
auto tensor = absl::make_unique<tf::Tensor>(
|
||||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||||
auto tensor_vec = tensor->flat<float>().data();
|
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||||
|
|
||||||
// Writing sequence of integers as floats which we want back (as they were
|
// Writing sequence of integers as floats which we want back (as they were
|
||||||
// written).
|
// written).
|
||||||
|
@ -92,15 +101,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||||
Adopt(tensor.release()).At(Timestamp(time)));
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner->Run().ok());
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag(kImage).packets;
|
runner->Outputs().Tag(kImage).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||||
|
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||||
EXPECT_EQ(kWidth, output_image.Width());
|
EXPECT_EQ(kWidth, output_image.Width());
|
||||||
EXPECT_EQ(kHeight, output_image.Height());
|
EXPECT_EQ(kHeight, output_image.Height());
|
||||||
|
|
||||||
|
@ -110,13 +120,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
|
TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||||
SetUpRunner();
|
Converts3DTensorToImageFrame2DGray) {
|
||||||
|
this->SetUpRunner();
|
||||||
|
auto& runner = this->runner_;
|
||||||
constexpr int kWidth = 16;
|
constexpr int kWidth = 16;
|
||||||
constexpr int kHeight = 8;
|
constexpr int kHeight = 8;
|
||||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{kHeight, kWidth});
|
const tf::TensorShape tensor_shape{kHeight, kWidth};
|
||||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
auto tensor = absl::make_unique<tf::Tensor>(
|
||||||
auto tensor_vec = tensor->flat<float>().data();
|
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||||
|
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||||
|
|
||||||
// Writing sequence of integers as floats which we want back (as they were
|
// Writing sequence of integers as floats which we want back (as they were
|
||||||
// written).
|
// written).
|
||||||
|
@ -125,15 +138,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||||
Adopt(tensor.release()).At(Timestamp(time)));
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner->Run().ok());
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag(kImage).packets;
|
runner->Outputs().Tag(kImage).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||||
|
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||||
EXPECT_EQ(kWidth, output_image.Width());
|
EXPECT_EQ(kWidth, output_image.Width());
|
||||||
EXPECT_EQ(kHeight, output_image.Height());
|
EXPECT_EQ(kHeight, output_image.Height());
|
||||||
|
|
||||||
|
|
|
@ -91,8 +91,6 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
|
||||||
// the input data when it arrives in Process(). In particular, if the header
|
// the input data when it arrives in Process(). In particular, if the header
|
||||||
// states that we produce a 1xD column vector, the input tensor must also be 1xD
|
// states that we produce a 1xD column vector, the input tensor must also be 1xD
|
||||||
//
|
//
|
||||||
// This designed was discussed in http://g/speakeranalysis/4uyx7cNRwJY and
|
|
||||||
// http://g/daredevil-project/VB26tcseUy8.
|
|
||||||
// Example Config
|
// Example Config
|
||||||
// node: {
|
// node: {
|
||||||
// calculator: "TensorToMatrixCalculator"
|
// calculator: "TensorToMatrixCalculator"
|
||||||
|
@ -158,22 +156,17 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
|
||||||
if (header_status.ok()) {
|
if (header_status.ok()) {
|
||||||
if (cc->Options<TensorToMatrixCalculatorOptions>()
|
if (cc->Options<TensorToMatrixCalculatorOptions>()
|
||||||
.has_time_series_header_overrides()) {
|
.has_time_series_header_overrides()) {
|
||||||
// From design discussions with Daredevil, we only want to support single
|
// This only supports a single sample per packet for now, so we hardcode
|
||||||
// sample per packet for now, so we hardcode the sample_rate based on the
|
// the sample_rate based on the packet_rate of the REFERENCE and fail
|
||||||
// packet_rate of the REFERENCE and fail noisily if we cannot. An
|
// if we cannot.
|
||||||
// alternative would be to calculate the sample_rate from the reference
|
|
||||||
// sample_rate and the change in num_samples between the reference and
|
|
||||||
// override headers:
|
|
||||||
// sample_rate_output = sample_rate_reference /
|
|
||||||
// (num_samples_override / num_samples_reference)
|
|
||||||
const TimeSeriesHeader& override_header =
|
const TimeSeriesHeader& override_header =
|
||||||
cc->Options<TensorToMatrixCalculatorOptions>()
|
cc->Options<TensorToMatrixCalculatorOptions>()
|
||||||
.time_series_header_overrides();
|
.time_series_header_overrides();
|
||||||
input_header->MergeFrom(override_header);
|
input_header->MergeFrom(override_header);
|
||||||
CHECK(input_header->has_packet_rate())
|
RET_CHECK(input_header->has_packet_rate())
|
||||||
<< "The TimeSeriesHeader.packet_rate must be set.";
|
<< "The TimeSeriesHeader.packet_rate must be set.";
|
||||||
if (!override_header.has_sample_rate()) {
|
if (!override_header.has_sample_rate()) {
|
||||||
CHECK_EQ(input_header->num_samples(), 1)
|
RET_CHECK_EQ(input_header->num_samples(), 1)
|
||||||
<< "Currently the time series can only output single samples.";
|
<< "Currently the time series can only output single samples.";
|
||||||
input_header->set_sample_rate(input_header->packet_rate());
|
input_header->set_sample_rate(input_header->packet_rate());
|
||||||
}
|
}
|
||||||
|
@ -186,20 +179,16 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
|
absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
|
||||||
// Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK
|
|
||||||
// failures. These are absolute conditions of the graph for the graph to be
|
|
||||||
// valid, and if it is violated by any input anywhere, the graph will be
|
|
||||||
// invalid for all inputs. A hard CHECK will enable faster debugging by
|
|
||||||
// immediately exiting and more prominently displaying error messages.
|
|
||||||
// Do not replace with RET_CHECKs.
|
|
||||||
|
|
||||||
// Verify that each reference stream packet corresponds to a tensor packet
|
// Verify that each reference stream packet corresponds to a tensor packet
|
||||||
// otherwise the header information is invalid. If we don't have a reference
|
// otherwise the header information is invalid. If we don't have a reference
|
||||||
// stream, Process() is only called when we have an input tensor and this is
|
// stream, Process() is only called when we have an input tensor and this is
|
||||||
// always True.
|
// always True.
|
||||||
CHECK(cc->Inputs().HasTag(kTensor))
|
RET_CHECK(cc->Inputs().HasTag(kTensor))
|
||||||
<< "Tensor stream not available at same timestamp as the reference "
|
<< "Tensor stream not available at same timestamp as the reference "
|
||||||
"stream.";
|
"stream.";
|
||||||
|
RET_CHECK(!cc->Inputs().Tag(kTensor).IsEmpty()) << "Tensor stream is empty.";
|
||||||
|
RET_CHECK_OK(cc->Inputs().Tag(kTensor).Value().ValidateAsType<tf::Tensor>())
|
||||||
|
<< "Tensor stream packet does not contain a Tensor.";
|
||||||
|
|
||||||
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
|
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
|
||||||
CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
|
CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
|
||||||
|
@ -207,13 +196,12 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
|
||||||
const int32 length = input_tensor.dim_size(input_tensor.dims() - 1);
|
const int32 length = input_tensor.dim_size(input_tensor.dims() - 1);
|
||||||
const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0);
|
const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0);
|
||||||
if (header_.has_num_channels()) {
|
if (header_.has_num_channels()) {
|
||||||
CHECK_EQ(length, header_.num_channels())
|
RET_CHECK_EQ(length, header_.num_channels())
|
||||||
<< "The number of channels at runtime does not match the header.";
|
<< "The number of channels at runtime does not match the header.";
|
||||||
}
|
}
|
||||||
if (header_.has_num_samples()) {
|
if (header_.has_num_samples()) {
|
||||||
CHECK_EQ(width, header_.num_samples())
|
RET_CHECK_EQ(width, header_.num_samples())
|
||||||
<< "The number of samples at runtime does not match the header.";
|
<< "The number of samples at runtime does not match the header.";
|
||||||
;
|
|
||||||
}
|
}
|
||||||
auto output = absl::make_unique<Matrix>(width, length);
|
auto output = absl::make_unique<Matrix>(width, length);
|
||||||
*output =
|
*output =
|
||||||
|
|
|
@ -98,388 +98,543 @@ class InferenceState {
|
||||||
|
|
||||||
// This calculator performs inference on a trained TensorFlow model.
|
// This calculator performs inference on a trained TensorFlow model.
|
||||||
//
|
//
|
||||||
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
// TensorFlow Sessions can be created from checkpoint paths, frozen models, or
|
||||||
// For this calculator it must include a tag_to_tensor_map.
|
// the SavedModel system. See the TensorFlowSessionFrom* packet generators for
|
||||||
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
// details. Each of these methods defines a mapping between MediaPipe streams
|
||||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
// and TensorFlow tensors. All of this information is passed in as an
|
||||||
cc->InputSidePackets()
|
// input_side_packet.
|
||||||
.Tag("RECURRENT_INIT_TENSORS")
|
//
|
||||||
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
// The input and output streams are TensorFlow tensors labeled by tags. The tags
|
||||||
}
|
// for the streams are matched to feeds and fetchs in a TensorFlow session using
|
||||||
return absl::OkStatus();
|
// a named_signature.generic_signature in the ModelManifest. The
|
||||||
}
|
// generic_signature is used as key-value pairs between the MediaPipe tag and
|
||||||
|
// the TensorFlow tensor. The signature_name in the options proto determines
|
||||||
|
// which named_signature is used. The keys in the generic_signature must be
|
||||||
|
// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
|
||||||
|
// the tensors corresponding to tags in the signature for input_streams are fed
|
||||||
|
// to the model and for output_streams the tensors are fetched from the model.
|
||||||
|
//
|
||||||
|
// Other calculators are used to convert data to and from tensors, this op only
|
||||||
|
// handles the TensorFlow session and batching. Batching occurs by concatenating
|
||||||
|
// input tensors along the 0th dimension across timestamps. If the 0th dimension
|
||||||
|
// is not a batch dimension, this calculator will add a 0th dimension by
|
||||||
|
// default. Setting add_batch_dim_to_tensors to false disables the dimension
|
||||||
|
// addition. Once batch_size inputs have been provided, the batch will be run
|
||||||
|
// and the output tensors sent out on the output streams with timestamps
|
||||||
|
// corresponding to the input stream packets. Setting the batch_size to 1
|
||||||
|
// completely disables batching, but is indepdent of add_batch_dim_to_tensors.
|
||||||
|
//
|
||||||
|
// The TensorFlowInferenceCalculator also support feeding states recurrently for
|
||||||
|
// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
|
||||||
|
// recurrent tensors. Initializing the recurrent state can be handled by the
|
||||||
|
// GraphTensorsPacketGenerator.
|
||||||
|
//
|
||||||
|
// The calculator updates two Counters to report timing information:
|
||||||
|
// --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
|
||||||
|
// --<name>-TotalProcessedTimestamps = # of instances processed
|
||||||
|
// (approximately batches processed * batch_size),
|
||||||
|
// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
|
||||||
|
// exists, or with TensorFlowInferenceCalculator if the name is not set. The
|
||||||
|
// name must be set for timing information to be instance-specific in graphs
|
||||||
|
// with multiple TensorFlowInferenceCalculators.
|
||||||
|
//
|
||||||
|
// Example config:
|
||||||
|
// packet_generator {
|
||||||
|
// packet_generator: "TensorFlowSessionFromSavedModelGenerator"
|
||||||
|
// output_side_packet: "tensorflow_session"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
|
||||||
|
// saved_model_path: "/path/to/saved/model"
|
||||||
|
// signature_name: "mediapipe"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// node {
|
||||||
|
// calculator: "TensorFlowInferenceCalculator"
|
||||||
|
// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
|
||||||
|
// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
|
||||||
|
// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
|
||||||
|
// input_side_packet: "SESSION:tensorflow_session"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Where the input and output streams are treated as Packet<tf::Tensor> and
|
||||||
|
// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
|
||||||
|
// "LABELS" and their respective tensors exported to /path/to/bundle. For an
|
||||||
|
// example of how this model was exported, see
|
||||||
|
// tensorflow_inference_test_graph_generator.py
|
||||||
|
//
|
||||||
|
// It is possible to use a GraphDef proto that was not exported by exporter (i.e
|
||||||
|
// without MetaGraph with bindings). Such GraphDef could contain all of its
|
||||||
|
// parameters in-lined (for example, it can be the output of freeze_graph.py).
|
||||||
|
// To instantiate a TensorFlow model from a GraphDef file, replace the
|
||||||
|
// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator:
|
||||||
|
//
|
||||||
|
// packet_generator {
|
||||||
|
// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
|
||||||
|
// output_side_packet: "SESSION:tensorflow_session"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
|
||||||
|
// graph_proto_path: "[PATH]"
|
||||||
|
// tag_to_tensor_names {
|
||||||
|
// key: "JPG_STRING"
|
||||||
|
// value: "input:0"
|
||||||
|
// }
|
||||||
|
// tag_to_tensor_names {
|
||||||
|
// key: "SOFTMAX"
|
||||||
|
// value: "softmax:0"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// It is also possible to use a GraphDef proto and checkpoint file that have not
|
||||||
|
// been frozen. This can be used to load graphs directly as they have been
|
||||||
|
// written from training. However, it is more brittle and you are encouraged to
|
||||||
|
// use a one of the more perminent formats described above. To instantiate a
|
||||||
|
// TensorFlow model from a GraphDef file and checkpoint, replace the
|
||||||
|
// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator:
|
||||||
|
//
|
||||||
|
// packet_generator {
|
||||||
|
// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator"
|
||||||
|
// output_side_packet: "SESSION:tensorflow_session"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: {
|
||||||
|
// graph_proto_path: "[PATH]"
|
||||||
|
// model_options {
|
||||||
|
// checkpoint_path: "[PATH2]"
|
||||||
|
// }
|
||||||
|
// tag_to_tensor_names {
|
||||||
|
// key: "JPG_STRING"
|
||||||
|
// value: "input:0"
|
||||||
|
// }
|
||||||
|
// tag_to_tensor_names {
|
||||||
|
// key: "SOFTMAX"
|
||||||
|
// value: "softmax:0"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
// Counters for recording timing information. The actual names have the value
|
||||||
|
// of CalculatorGraphConfig::Node::name() prepended.
|
||||||
|
static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs";
|
||||||
|
static constexpr char kTotalProcessedTimestampsCounterSuffix[] =
|
||||||
|
"TotalProcessedTimestamps";
|
||||||
|
static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] =
|
||||||
|
"TotalSessionRunsTimeUsecs";
|
||||||
|
static constexpr char kTotalNumSessionRunsCounterSuffix[] =
|
||||||
|
"TotalNumSessionRuns";
|
||||||
|
|
||||||
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
|
TensorFlowInferenceCalculator() : session_(nullptr) {
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
clock_ = std::unique_ptr<mediapipe::Clock>(
|
||||||
std::unique_ptr<InferenceState> inference_state =
|
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
|
||||||
absl::make_unique<InferenceState>();
|
}
|
||||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
|
||||||
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
std::map<std::string, tf::Tensor>* init_tensor_map;
|
const auto& options = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||||
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||||
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||||
for (const auto& p : *init_tensor_map) {
|
// The tensorflow::Tensor with the tag equal to the graph node. May
|
||||||
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
// have a TimeSeriesHeader if all present TimeSeriesHeaders match.
|
||||||
|
if (!options.batched_input()) {
|
||||||
|
cc->Inputs().Tag(tag).Set<tf::Tensor>();
|
||||||
|
} else {
|
||||||
|
cc->Inputs().Tag(tag).Set<std::vector<mediapipe::Packet>>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||||
return inference_state;
|
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||||
}
|
// The tensorflow::Tensor with tag equal to the graph node to
|
||||||
|
// output. Any TimeSeriesHeader from the inputs will be forwarded
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
// with channels set to 0.
|
||||||
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
cc->Outputs().Tag(tag).Set<tf::Tensor>();
|
||||||
|
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
|
||||||
session_ = cc->InputSidePackets()
|
|
||||||
.Tag("SESSION")
|
|
||||||
.Get<TensorFlowSession>()
|
|
||||||
.session.get();
|
|
||||||
tag_to_tensor_map_ = cc->InputSidePackets()
|
|
||||||
.Tag("SESSION")
|
|
||||||
.Get<TensorFlowSession>()
|
|
||||||
.tag_to_tensor_map;
|
|
||||||
|
|
||||||
// Validate and store the recurrent tags
|
|
||||||
RET_CHECK(options_.has_batch_size());
|
|
||||||
RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty())
|
|
||||||
<< "To use recurrent_tag_pairs, batch_size must be 1.";
|
|
||||||
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
|
|
||||||
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
|
|
||||||
RET_CHECK_EQ(tags.size(), 2)
|
|
||||||
<< "recurrent_tag_pair must be a colon "
|
|
||||||
"separated std::string with two components: "
|
|
||||||
<< tag_pair;
|
|
||||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
|
|
||||||
<< "Can't find tag '" << tags[0] << "' in signature "
|
|
||||||
<< options_.signature_name();
|
|
||||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
|
|
||||||
<< "Can't find tag '" << tags[1] << "' in signature "
|
|
||||||
<< options_.signature_name();
|
|
||||||
recurrent_feed_tags_.insert(tags[0]);
|
|
||||||
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that all tags are present in this signature bound to tensors.
|
|
||||||
for (const std::string& tag : cc->Inputs().GetTags()) {
|
|
||||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
|
||||||
<< "Can't find tag '" << tag << "' in signature "
|
|
||||||
<< options_.signature_name();
|
|
||||||
}
|
|
||||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
|
||||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
|
||||||
<< "Can't find tag '" << tag << "' in signature "
|
|
||||||
<< options_.signature_name();
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
absl::WriterMutexLock l(&mutex_);
|
|
||||||
inference_state_ = std::unique_ptr<InferenceState>();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options_.batch_size() == 1 || options_.batched_input()) {
|
|
||||||
cc->SetOffset(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adds a batch dimension to the input tensor if specified in the calculator
|
|
||||||
// options.
|
|
||||||
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
|
|
||||||
if (options_.add_batch_dim_to_tensors()) {
|
|
||||||
tf::TensorShape new_shape(input_tensor->shape());
|
|
||||||
new_shape.InsertDim(0, 1);
|
|
||||||
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
|
|
||||||
<< "Could not add 0th dimension to tensor without changing its shape."
|
|
||||||
<< " Current shape: " << input_tensor->shape().DebugString();
|
|
||||||
}
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status AggregateTensorPacket(
|
|
||||||
const std::string& tag_name, const Packet& packet,
|
|
||||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
|
|
||||||
input_tensors_by_tag_by_timestamp,
|
|
||||||
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
|
||||||
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
|
|
||||||
RET_CHECK_OK(AddBatchDimension(&input_tensor));
|
|
||||||
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
|
|
||||||
// If we receive an input on a recurrent tag, override the state.
|
|
||||||
// It's OK to override the global state because there is just one
|
|
||||||
// input stream allowed for recurrent tensors.
|
|
||||||
inference_state_->input_tensor_batches_[tag_name].clear();
|
|
||||||
}
|
|
||||||
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
|
|
||||||
std::make_pair(tag_name, input_tensor));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Removes the batch dimension of the output tensor if specified in the
|
|
||||||
// calculator options.
|
|
||||||
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
|
|
||||||
if (options_.add_batch_dim_to_tensors()) {
|
|
||||||
tf::TensorShape new_shape(output_tensor->shape());
|
|
||||||
new_shape.RemoveDim(0);
|
|
||||||
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
|
|
||||||
<< "Could not remove 0th dimension from tensor without changing its "
|
|
||||||
<< "shape. Current shape: " << output_tensor->shape().DebugString()
|
|
||||||
<< " (The expected first dimension is 1 for a batch element.)";
|
|
||||||
}
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
|
||||||
std::unique_ptr<InferenceState> inference_state_to_process;
|
|
||||||
{
|
|
||||||
absl::WriterMutexLock l(&mutex_);
|
|
||||||
if (inference_state_ == nullptr) {
|
|
||||||
inference_state_ = CreateInferenceState(cc);
|
|
||||||
}
|
}
|
||||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>
|
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
||||||
input_tensors_by_tag_by_timestamp;
|
// For this calculator it must include a tag_to_tensor_map.
|
||||||
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
|
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
||||||
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
|
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
||||||
// Recurrent tensors can be empty.
|
cc->InputSidePackets()
|
||||||
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
|
.Tag("RECURRENT_INIT_TENSORS")
|
||||||
if (options_.skip_on_missing_features()) {
|
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
||||||
return absl::OkStatus();
|
}
|
||||||
} else {
|
return absl::OkStatus();
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
}
|
||||||
"Tag ", tag_as_node_name,
|
|
||||||
" not present at timestamp: ", cc->InputTimestamp().Value()));
|
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
|
||||||
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||||
|
std::unique_ptr<InferenceState> inference_state =
|
||||||
|
absl::make_unique<InferenceState>();
|
||||||
|
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
||||||
|
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
||||||
|
std::map<std::string, tf::Tensor>* init_tensor_map;
|
||||||
|
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
||||||
|
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
||||||
|
for (const auto& p : *init_tensor_map) {
|
||||||
|
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inference_state;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||||
|
|
||||||
|
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
||||||
|
session_ = cc->InputSidePackets()
|
||||||
|
.Tag("SESSION")
|
||||||
|
.Get<TensorFlowSession>()
|
||||||
|
.session.get();
|
||||||
|
tag_to_tensor_map_ = cc->InputSidePackets()
|
||||||
|
.Tag("SESSION")
|
||||||
|
.Get<TensorFlowSession>()
|
||||||
|
.tag_to_tensor_map;
|
||||||
|
|
||||||
|
// Validate and store the recurrent tags
|
||||||
|
RET_CHECK(options_.has_batch_size());
|
||||||
|
RET_CHECK(options_.batch_size() == 1 ||
|
||||||
|
options_.recurrent_tag_pair().empty())
|
||||||
|
<< "To use recurrent_tag_pairs, batch_size must be 1.";
|
||||||
|
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
|
||||||
|
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
|
||||||
|
RET_CHECK_EQ(tags.size(), 2)
|
||||||
|
<< "recurrent_tag_pair must be a colon "
|
||||||
|
"separated std::string with two components: "
|
||||||
|
<< tag_pair;
|
||||||
|
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
|
||||||
|
<< "Can't find tag '" << tags[0] << "' in signature "
|
||||||
|
<< options_.signature_name();
|
||||||
|
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
|
||||||
|
<< "Can't find tag '" << tags[1] << "' in signature "
|
||||||
|
<< options_.signature_name();
|
||||||
|
recurrent_feed_tags_.insert(tags[0]);
|
||||||
|
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all tags are present in this signature bound to tensors.
|
||||||
|
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||||
|
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||||
|
<< "Can't find tag '" << tag << "' in signature "
|
||||||
|
<< options_.signature_name();
|
||||||
|
}
|
||||||
|
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||||
|
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||||
|
<< "Can't find tag '" << tag << "' in signature "
|
||||||
|
<< options_.signature_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
absl::WriterMutexLock l(&mutex_);
|
||||||
|
inference_state_ = std::unique_ptr<InferenceState>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options_.batch_size() == 1 || options_.batched_input()) {
|
||||||
|
cc->SetOffset(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds a batch dimension to the input tensor if specified in the calculator
|
||||||
|
// options.
|
||||||
|
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
|
||||||
|
if (options_.add_batch_dim_to_tensors()) {
|
||||||
|
tf::TensorShape new_shape(input_tensor->shape());
|
||||||
|
new_shape.InsertDim(0, 1);
|
||||||
|
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
|
||||||
|
<< "Could not add 0th dimension to tensor without changing its shape."
|
||||||
|
<< " Current shape: " << input_tensor->shape().DebugString();
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status AggregateTensorPacket(
|
||||||
|
const std::string& tag_name, const Packet& packet,
|
||||||
|
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
|
||||||
|
input_tensors_by_tag_by_timestamp,
|
||||||
|
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||||
|
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
|
||||||
|
RET_CHECK_OK(AddBatchDimension(&input_tensor));
|
||||||
|
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
|
||||||
|
// If we receive an input on a recurrent tag, override the state.
|
||||||
|
// It's OK to override the global state because there is just one
|
||||||
|
// input stream allowed for recurrent tensors.
|
||||||
|
inference_state_->input_tensor_batches_[tag_name].clear();
|
||||||
|
}
|
||||||
|
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
|
||||||
|
std::make_pair(tag_name, input_tensor));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes the batch dimension of the output tensor if specified in the
|
||||||
|
// calculator options.
|
||||||
|
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
|
||||||
|
if (options_.add_batch_dim_to_tensors()) {
|
||||||
|
tf::TensorShape new_shape(output_tensor->shape());
|
||||||
|
new_shape.RemoveDim(0);
|
||||||
|
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
|
||||||
|
<< "Could not remove 0th dimension from tensor without changing its "
|
||||||
|
<< "shape. Current shape: " << output_tensor->shape().DebugString()
|
||||||
|
<< " (The expected first dimension is 1 for a batch element.)";
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
std::unique_ptr<InferenceState> inference_state_to_process;
|
||||||
|
{
|
||||||
|
absl::WriterMutexLock l(&mutex_);
|
||||||
|
if (inference_state_ == nullptr) {
|
||||||
|
inference_state_ = CreateInferenceState(cc);
|
||||||
|
}
|
||||||
|
std::map<Timestamp, std::map<std::string, tf::Tensor>>
|
||||||
|
input_tensors_by_tag_by_timestamp;
|
||||||
|
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
|
||||||
|
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
|
||||||
|
// Recurrent tensors can be empty.
|
||||||
|
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
|
||||||
|
if (options_.skip_on_missing_features()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
} else {
|
||||||
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
|
"Tag ", tag_as_node_name,
|
||||||
|
" not present at timestamp: ", cc->InputTimestamp().Value()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else if (options_.batched_input()) {
|
||||||
|
const auto& tensor_packets =
|
||||||
|
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
|
||||||
|
if (tensor_packets.size() > options_.batch_size()) {
|
||||||
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
|
"Batch for tag ", tag_as_node_name,
|
||||||
|
" has more packets than batch capacity. batch_size: ",
|
||||||
|
options_.batch_size(), " packets: ", tensor_packets.size()));
|
||||||
|
}
|
||||||
|
for (const auto& packet : tensor_packets) {
|
||||||
|
RET_CHECK_OK(AggregateTensorPacket(
|
||||||
|
tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp,
|
||||||
|
inference_state_.get()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RET_CHECK_OK(AggregateTensorPacket(
|
||||||
|
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
|
||||||
|
&input_tensors_by_tag_by_timestamp, inference_state_.get()));
|
||||||
}
|
}
|
||||||
} else if (options_.batched_input()) {
|
}
|
||||||
const auto& tensor_packets =
|
for (const auto& timestamp_and_input_tensors_by_tag :
|
||||||
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
|
input_tensors_by_tag_by_timestamp) {
|
||||||
if (tensor_packets.size() > options_.batch_size()) {
|
inference_state_->batch_timestamps_.emplace_back(
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
timestamp_and_input_tensors_by_tag.first);
|
||||||
"Batch for tag ", tag_as_node_name,
|
for (const auto& input_tensor_and_tag :
|
||||||
" has more packets than batch capacity. batch_size: ",
|
timestamp_and_input_tensors_by_tag.second) {
|
||||||
options_.batch_size(), " packets: ", tensor_packets.size()));
|
inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
|
||||||
|
.emplace_back(input_tensor_and_tag.second);
|
||||||
}
|
}
|
||||||
for (const auto& packet : tensor_packets) {
|
}
|
||||||
RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet,
|
if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
|
||||||
&input_tensors_by_tag_by_timestamp,
|
options_.batched_input()) {
|
||||||
inference_state_.get()));
|
inference_state_to_process = std::move(inference_state_);
|
||||||
|
inference_state_ = std::unique_ptr<InferenceState>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inference_state_to_process) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
OutputBatch(cc, std::move(inference_state_to_process)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Close(CalculatorContext* cc) override {
|
||||||
|
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
|
||||||
|
{
|
||||||
|
absl::WriterMutexLock l(&mutex_);
|
||||||
|
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
|
||||||
|
!inference_state_->batch_timestamps_.empty()) {
|
||||||
|
inference_state_to_process = std::move(inference_state_);
|
||||||
|
inference_state_ = std::unique_ptr<InferenceState>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (inference_state_to_process) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
OutputBatch(cc, std::move(inference_state_to_process)));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// When a batch of input tensors is ready to be run, runs TensorFlow and
|
||||||
|
// outputs the output tensors. The output tensors have timestamps matching
|
||||||
|
// the input tensor that formed that batch element. Any requested
|
||||||
|
// batch_dimension is added and removed. This code takes advantage of the fact
|
||||||
|
// that copying a tensor shares the same reference-counted, heap allocated
|
||||||
|
// memory buffer. Therefore, copies are cheap and should not cause the memory
|
||||||
|
// buffer to fall out of scope. In contrast, concat is only used where
|
||||||
|
// necessary.
|
||||||
|
absl::Status OutputBatch(CalculatorContext* cc,
|
||||||
|
std::unique_ptr<InferenceState> inference_state) {
|
||||||
|
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
|
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||||
|
|
||||||
|
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||||
|
if (options_.batch_size() == 1) {
|
||||||
|
// Short circuit to avoid the cost of deep copying tensors in concat.
|
||||||
|
if (!keyed_tensors.second.empty()) {
|
||||||
|
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
||||||
|
keyed_tensors.second[0]);
|
||||||
|
} else {
|
||||||
|
// The input buffer can be empty for recurrent tensors.
|
||||||
|
RET_CHECK(
|
||||||
|
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
|
||||||
|
<< "A non-recurrent tensor does not have an input: "
|
||||||
|
<< keyed_tensors.first;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK_OK(AggregateTensorPacket(
|
// Pad by replicating the first tens or, then ignore the values.
|
||||||
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
|
keyed_tensors.second.resize(options_.batch_size());
|
||||||
&input_tensors_by_tag_by_timestamp, inference_state_.get()));
|
std::fill(keyed_tensors.second.begin() +
|
||||||
}
|
inference_state->batch_timestamps_.size(),
|
||||||
}
|
keyed_tensors.second.end(), keyed_tensors.second[0]);
|
||||||
for (const auto& timestamp_and_input_tensors_by_tag :
|
tf::Tensor concated;
|
||||||
input_tensors_by_tag_by_timestamp) {
|
const tf::Status concat_status =
|
||||||
inference_state_->batch_timestamps_.emplace_back(
|
tf::tensor::Concat(keyed_tensors.second, &concated);
|
||||||
timestamp_and_input_tensors_by_tag.first);
|
CHECK(concat_status.ok()) << concat_status.ToString();
|
||||||
for (const auto& input_tensor_and_tag :
|
|
||||||
timestamp_and_input_tensors_by_tag.second) {
|
|
||||||
inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
|
|
||||||
.emplace_back(input_tensor_and_tag.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
|
|
||||||
options_.batched_input()) {
|
|
||||||
inference_state_to_process = std::move(inference_state_);
|
|
||||||
inference_state_ = std::unique_ptr<InferenceState>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inference_state_to_process) {
|
|
||||||
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
|
|
||||||
}
|
|
||||||
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Close(CalculatorContext* cc) override {
|
|
||||||
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
|
|
||||||
{
|
|
||||||
absl::WriterMutexLock l(&mutex_);
|
|
||||||
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
|
|
||||||
!inference_state_->batch_timestamps_.empty()) {
|
|
||||||
inference_state_to_process = std::move(inference_state_);
|
|
||||||
inference_state_ = std::unique_ptr<InferenceState>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (inference_state_to_process) {
|
|
||||||
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
|
|
||||||
}
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// When a batch of input tensors is ready to be run, runs TensorFlow and
|
|
||||||
// outputs the output tensors. The output tensors have timestamps matching
|
|
||||||
// the input tensor that formed that batch element. Any requested
|
|
||||||
// batch_dimension is added and removed. This code takes advantage of the fact
|
|
||||||
// that copying a tensor shares the same reference-counted, heap allocated
|
|
||||||
// memory buffer. Therefore, copies are cheap and should not cause the memory
|
|
||||||
// buffer to fall out of scope. In contrast, concat is only used where
|
|
||||||
// necessary.
|
|
||||||
absl::Status OutputBatch(CalculatorContext* cc,
|
|
||||||
std::unique_ptr<InferenceState> inference_state) {
|
|
||||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
|
||||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
|
||||||
|
|
||||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
|
||||||
if (options_.batch_size() == 1) {
|
|
||||||
// Short circuit to avoid the cost of deep copying tensors in concat.
|
|
||||||
if (!keyed_tensors.second.empty()) {
|
|
||||||
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
||||||
keyed_tensors.second[0]);
|
concated);
|
||||||
} else {
|
|
||||||
// The input buffer can be empty for recurrent tensors.
|
|
||||||
RET_CHECK(
|
|
||||||
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
|
|
||||||
<< "A non-recurrent tensor does not have an input: "
|
|
||||||
<< keyed_tensors.first;
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Pad by replicating the first tens or, then ignore the values.
|
|
||||||
keyed_tensors.second.resize(options_.batch_size());
|
|
||||||
std::fill(keyed_tensors.second.begin() +
|
|
||||||
inference_state->batch_timestamps_.size(),
|
|
||||||
keyed_tensors.second.end(), keyed_tensors.second[0]);
|
|
||||||
tf::Tensor concated;
|
|
||||||
const tf::Status concat_status =
|
|
||||||
tf::tensor::Concat(keyed_tensors.second, &concated);
|
|
||||||
CHECK(concat_status.ok()) << concat_status.ToString();
|
|
||||||
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
|
||||||
concated);
|
|
||||||
}
|
}
|
||||||
}
|
inference_state->input_tensor_batches_.clear();
|
||||||
inference_state->input_tensor_batches_.clear();
|
std::vector<mediapipe::ProtoString> output_tensor_names;
|
||||||
std::vector<mediapipe::ProtoString> output_tensor_names;
|
std::vector<std::string> output_name_in_signature;
|
||||||
std::vector<std::string> output_name_in_signature;
|
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
|
||||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
|
output_name_in_signature.emplace_back(tag);
|
||||||
output_name_in_signature.emplace_back(tag);
|
|
||||||
}
|
|
||||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
|
||||||
// Ensure that we always fetch the recurrent state tensors.
|
|
||||||
if (std::find(output_name_in_signature.begin(),
|
|
||||||
output_name_in_signature.end(),
|
|
||||||
tag_pair.first) == output_name_in_signature.end()) {
|
|
||||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
|
|
||||||
output_name_in_signature.emplace_back(tag_pair.first);
|
|
||||||
}
|
}
|
||||||
}
|
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||||
std::vector<tf::Tensor> outputs;
|
// Ensure that we always fetch the recurrent state tensors.
|
||||||
|
if (std::find(output_name_in_signature.begin(),
|
||||||
|
output_name_in_signature.end(),
|
||||||
|
tag_pair.first) == output_name_in_signature.end()) {
|
||||||
|
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
|
||||||
|
output_name_in_signature.emplace_back(tag_pair.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<tf::Tensor> outputs;
|
||||||
|
|
||||||
SimpleSemaphore* session_run_throttle = nullptr;
|
SimpleSemaphore* session_run_throttle = nullptr;
|
||||||
if (options_.max_concurrent_session_runs() > 0) {
|
if (options_.max_concurrent_session_runs() > 0) {
|
||||||
session_run_throttle =
|
session_run_throttle =
|
||||||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||||
session_run_throttle->Acquire(1);
|
session_run_throttle->Acquire(1);
|
||||||
}
|
}
|
||||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
tf::Status tf_status;
|
tf::Status tf_status;
|
||||||
{
|
{
|
||||||
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
||||||
tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
|
tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
|
||||||
#endif
|
#endif
|
||||||
tf_status = session_->Run(input_tensors, output_tensor_names,
|
tf_status = session_->Run(input_tensors, output_tensor_names,
|
||||||
{} /* target_node_names */, &outputs);
|
{} /* target_node_names */, &outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (session_run_throttle != nullptr) {
|
if (session_run_throttle != nullptr) {
|
||||||
session_run_throttle->Release(1);
|
session_run_throttle->Release(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// RET_CHECK on the tf::Status object itself in order to print an
|
// RET_CHECK on the tf::Status object itself in order to print an
|
||||||
// informative error message.
|
// informative error message.
|
||||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||||
|
|
||||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||||
->IncrementBy(run_end_time - run_start_time);
|
->IncrementBy(run_end_time - run_start_time);
|
||||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||||
|
|
||||||
// Feed back the recurrent state.
|
// Feed back the recurrent state.
|
||||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||||
int pos = std::find(output_name_in_signature.begin(),
|
int pos = std::find(output_name_in_signature.begin(),
|
||||||
output_name_in_signature.end(), tag_pair.first) -
|
output_name_in_signature.end(), tag_pair.first) -
|
||||||
output_name_in_signature.begin();
|
output_name_in_signature.begin();
|
||||||
inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
|
inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
|
||||||
outputs[pos]);
|
outputs[pos]);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::WriterMutexLock l(&mutex_);
|
absl::WriterMutexLock l(&mutex_);
|
||||||
// Set that we want to split on each index of the 0th dimension.
|
// Set that we want to split on each index of the 0th dimension.
|
||||||
std::vector<tf::int64> split_vector(options_.batch_size(), 1);
|
std::vector<tf::int64> split_vector(options_.batch_size(), 1);
|
||||||
for (int i = 0; i < output_tensor_names.size(); ++i) {
|
for (int i = 0; i < output_tensor_names.size(); ++i) {
|
||||||
if (options_.batch_size() == 1) {
|
if (options_.batch_size() == 1) {
|
||||||
if (cc->Outputs().HasTag(output_name_in_signature[i])) {
|
if (cc->Outputs().HasTag(output_name_in_signature[i])) {
|
||||||
tf::Tensor output_tensor(outputs[i]);
|
tf::Tensor output_tensor(outputs[i]);
|
||||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(output_name_in_signature[i])
|
.Tag(output_name_in_signature[i])
|
||||||
.Add(new tf::Tensor(output_tensor),
|
.Add(new tf::Tensor(output_tensor),
|
||||||
inference_state->batch_timestamps_[0]);
|
inference_state->batch_timestamps_[0]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::vector<tf::Tensor> split_tensors;
|
std::vector<tf::Tensor> split_tensors;
|
||||||
const tf::Status split_status =
|
const tf::Status split_status =
|
||||||
tf::tensor::Split(outputs[i], split_vector, &split_tensors);
|
tf::tensor::Split(outputs[i], split_vector, &split_tensors);
|
||||||
CHECK(split_status.ok()) << split_status.ToString();
|
CHECK(split_status.ok()) << split_status.ToString();
|
||||||
// Loop over timestamps so that we don't copy the padding.
|
// Loop over timestamps so that we don't copy the padding.
|
||||||
for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
|
for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
|
||||||
tf::Tensor output_tensor(split_tensors[j]);
|
tf::Tensor output_tensor(split_tensors[j]);
|
||||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag(output_name_in_signature[i])
|
.Tag(output_name_in_signature[i])
|
||||||
.Add(new tf::Tensor(output_tensor),
|
.Add(new tf::Tensor(output_tensor),
|
||||||
inference_state->batch_timestamps_[j]);
|
inference_state->batch_timestamps_[j]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get end time and report.
|
||||||
|
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
|
cc->GetCounter(kTotalUsecsCounterSuffix)
|
||||||
|
->IncrementBy(end_time - start_time);
|
||||||
|
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||||
|
->IncrementBy(inference_state->batch_timestamps_.size());
|
||||||
|
|
||||||
|
// Make sure we hold on to the recursive state.
|
||||||
|
if (!options_.recurrent_tag_pair().empty()) {
|
||||||
|
inference_state_ = std::move(inference_state);
|
||||||
|
inference_state_->batch_timestamps_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get end time and report.
|
private:
|
||||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
// The Session object is provided by a packet factory and is owned by the
|
||||||
cc->GetCounter(kTotalUsecsCounterSuffix)->IncrementBy(end_time - start_time);
|
// MediaPipe framework. Individual calls are thread-safe, but session state
|
||||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
// may be shared across threads.
|
||||||
->IncrementBy(inference_state->batch_timestamps_.size());
|
tf::Session* session_;
|
||||||
|
|
||||||
// Make sure we hold on to the recursive state.
|
// A mapping between stream tags and the tensor names they are bound to.
|
||||||
if (!options_.recurrent_tag_pair().empty()) {
|
std::map<std::string, std::string> tag_to_tensor_map_;
|
||||||
inference_state_ = std::move(inference_state);
|
|
||||||
inference_state_->batch_timestamps_.clear();
|
absl::Mutex mutex_;
|
||||||
|
std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);
|
||||||
|
|
||||||
|
// The options for the calculator.
|
||||||
|
TensorFlowInferenceCalculatorOptions options_;
|
||||||
|
|
||||||
|
// Store the feed and fetch tags for feed/fetch recurrent networks.
|
||||||
|
std::set<std::string> recurrent_feed_tags_;
|
||||||
|
std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;
|
||||||
|
|
||||||
|
// Clock used to measure the computation time in OutputBatch().
|
||||||
|
std::unique_ptr<mediapipe::Clock> clock_;
|
||||||
|
|
||||||
|
// The static singleton semaphore to throttle concurrent session runs.
|
||||||
|
static SimpleSemaphore* get_session_run_throttle(
|
||||||
|
int32 max_concurrent_session_runs) {
|
||||||
|
static SimpleSemaphore* session_run_throttle =
|
||||||
|
new SimpleSemaphore(max_concurrent_session_runs);
|
||||||
|
return session_run_throttle;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// The Session object is provided by a packet factory and is owned by the
|
|
||||||
// MediaPipe framework. Individual calls are thread-safe, but session state may
|
|
||||||
// be shared across threads.
|
|
||||||
tf::Session* session_;
|
|
||||||
|
|
||||||
// A mapping between stream tags and the tensor names they are bound to.
|
|
||||||
std::map<std::string, std::string> tag_to_tensor_map_;
|
|
||||||
|
|
||||||
absl::Mutex mutex_;
|
|
||||||
std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);
|
|
||||||
|
|
||||||
// The options for the calculator.
|
|
||||||
TensorFlowInferenceCalculatorOptions options_;
|
|
||||||
|
|
||||||
// Store the feed and fetch tags for feed/fetch recurrent networks.
|
|
||||||
std::set<std::string> recurrent_feed_tags_;
|
|
||||||
std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;
|
|
||||||
|
|
||||||
// Clock used to measure the computation time in OutputBatch().
|
|
||||||
std::unique_ptr<mediapipe::Clock> clock_;
|
|
||||||
|
|
||||||
// The static singleton semaphore to throttle concurrent session runs.
|
|
||||||
static SimpleSemaphore* get_session_run_throttle(
|
|
||||||
int32 max_concurrent_session_runs) {
|
|
||||||
static SimpleSemaphore* session_run_throttle =
|
|
||||||
new SimpleSemaphore(max_concurrent_session_runs);
|
|
||||||
return session_run_throttle;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
;
|
|
||||||
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);
|
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);
|
||||||
|
|
||||||
constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];
|
constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];
|
||||||
|
|
|
@ -80,6 +80,7 @@ const std::string MaybeConvertSignatureToTag(
|
||||||
// which in turn contains a TensorFlow Session ready for execution and a map
|
// which in turn contains a TensorFlow Session ready for execution and a map
|
||||||
// between tags and tensor names.
|
// between tags and tensor names.
|
||||||
//
|
//
|
||||||
|
//
|
||||||
// Example usage:
|
// Example usage:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "TensorFlowSessionFromSavedModelCalculator"
|
// calculator: "TensorFlowSessionFromSavedModelCalculator"
|
||||||
|
|
|
@ -217,38 +217,41 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
first_timestamp_seen_ = recent_timestamp;
|
first_timestamp_seen_ = recent_timestamp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (recent_timestamp > last_timestamp_seen) {
|
if (recent_timestamp > last_timestamp_seen &&
|
||||||
|
recent_timestamp < Timestamp::PostStream().Value()) {
|
||||||
last_timestamp_key_ = map_kv.first;
|
last_timestamp_key_ = map_kv.first;
|
||||||
last_timestamp_seen = recent_timestamp;
|
last_timestamp_seen = recent_timestamp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!timestamps_.empty()) {
|
if (!timestamps_.empty()) {
|
||||||
RET_CHECK(!last_timestamp_key_.empty())
|
for (const auto& kv : timestamps_) {
|
||||||
<< "Something went wrong because the timestamp key is unset. "
|
if (!kv.second.empty() &&
|
||||||
"Example: "
|
kv.second[0] < Timestamp::PostStream().Value()) {
|
||||||
<< sequence_->DebugString();
|
// These checks only make sense if any values are not PostStream, but
|
||||||
RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
|
// only need to be made once.
|
||||||
<< "Something went wrong because the last timestamp is unset. "
|
RET_CHECK(!last_timestamp_key_.empty())
|
||||||
"Example: "
|
<< "Something went wrong because the timestamp key is unset. "
|
||||||
<< sequence_->DebugString();
|
<< "Example: " << sequence_->DebugString();
|
||||||
RET_CHECK_LT(first_timestamp_seen_,
|
RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
|
||||||
Timestamp::OneOverPostStream().Value())
|
<< "Something went wrong because the last timestamp is unset. "
|
||||||
<< "Something went wrong because the first timestamp is unset. "
|
<< "Example: " << sequence_->DebugString();
|
||||||
"Example: "
|
RET_CHECK_LT(first_timestamp_seen_,
|
||||||
<< sequence_->DebugString();
|
Timestamp::OneOverPostStream().Value())
|
||||||
|
<< "Something went wrong because the first timestamp is unset. "
|
||||||
|
<< "Example: " << sequence_->DebugString();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
current_timestamp_index_ = 0;
|
current_timestamp_index_ = 0;
|
||||||
|
process_poststream_ = false;
|
||||||
|
|
||||||
// Determine the data path and output it.
|
// Determine the data path and output it.
|
||||||
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
|
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
|
||||||
const auto& sequence = cc->InputSidePackets()
|
const auto& sequence = cc->InputSidePackets()
|
||||||
.Tag(kSequenceExampleTag)
|
.Tag(kSequenceExampleTag)
|
||||||
.Get<tensorflow::SequenceExample>();
|
.Get<tensorflow::SequenceExample>();
|
||||||
if (cc->Outputs().HasTag(kKeypointsTag)) {
|
|
||||||
keypoint_names_ = absl::StrSplit(options.keypoint_names(), ',');
|
|
||||||
default_keypoint_location_ = options.default_keypoint_location();
|
|
||||||
}
|
|
||||||
if (cc->OutputSidePackets().HasTag(kDataPath)) {
|
if (cc->OutputSidePackets().HasTag(kDataPath)) {
|
||||||
std::string root_directory = "";
|
std::string root_directory = "";
|
||||||
if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
|
if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
|
||||||
|
@ -349,19 +352,30 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// all packets on all streams that have a timestamp between the current
|
// all packets on all streams that have a timestamp between the current
|
||||||
// reference timestep and the previous reference timestep. This ensures that
|
// reference timestep and the previous reference timestep. This ensures that
|
||||||
// we emit all timestamps in order, but also only emit a limited number in
|
// we emit all timestamps in order, but also only emit a limited number in
|
||||||
// any particular call to Process().
|
// any particular call to Process(). At the every end, we output the
|
||||||
int64 start_timestamp =
|
// poststream packets. If we only have poststream packets,
|
||||||
timestamps_[last_timestamp_key_][current_timestamp_index_];
|
// last_timestamp_key_ will be empty.
|
||||||
if (current_timestamp_index_ == 0) {
|
int64 start_timestamp = 0;
|
||||||
start_timestamp = first_timestamp_seen_;
|
int64 end_timestamp = 0;
|
||||||
|
if (last_timestamp_key_.empty() || process_poststream_) {
|
||||||
|
process_poststream_ = true;
|
||||||
|
start_timestamp = Timestamp::PostStream().Value();
|
||||||
|
end_timestamp = Timestamp::OneOverPostStream().Value();
|
||||||
|
} else {
|
||||||
|
start_timestamp =
|
||||||
|
timestamps_[last_timestamp_key_][current_timestamp_index_];
|
||||||
|
if (current_timestamp_index_ == 0) {
|
||||||
|
start_timestamp = first_timestamp_seen_;
|
||||||
|
}
|
||||||
|
|
||||||
|
end_timestamp = start_timestamp + 1; // Base case at end of sequence.
|
||||||
|
if (current_timestamp_index_ <
|
||||||
|
timestamps_[last_timestamp_key_].size() - 1) {
|
||||||
|
end_timestamp =
|
||||||
|
timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 end_timestamp = start_timestamp + 1; // Base case at end of sequence.
|
|
||||||
if (current_timestamp_index_ <
|
|
||||||
timestamps_[last_timestamp_key_].size() - 1) {
|
|
||||||
end_timestamp =
|
|
||||||
timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
|
|
||||||
}
|
|
||||||
for (const auto& map_kv : timestamps_) {
|
for (const auto& map_kv : timestamps_) {
|
||||||
for (int i = 0; i < map_kv.second.size(); ++i) {
|
for (int i = 0; i < map_kv.second.size(); ++i) {
|
||||||
if (map_kv.second[i] >= start_timestamp &&
|
if (map_kv.second[i] >= start_timestamp &&
|
||||||
|
@ -438,7 +452,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) {
|
if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
} else {
|
} else {
|
||||||
return tool::StatusStop();
|
if (process_poststream_) {
|
||||||
|
// Once we've processed the PostStream timestamp we can stop.
|
||||||
|
return tool::StatusStop();
|
||||||
|
} else {
|
||||||
|
// Otherwise, we still need to do one more pass to process it.
|
||||||
|
process_poststream_ = true;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -462,6 +483,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
std::vector<std::string> keypoint_names_;
|
std::vector<std::string> keypoint_names_;
|
||||||
// Default keypoint location when missing.
|
// Default keypoint location when missing.
|
||||||
float default_keypoint_location_;
|
float default_keypoint_location_;
|
||||||
|
bool process_poststream_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
|
REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -412,6 +412,72 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
||||||
::testing::Eq(Timestamp::PostStream()));
|
::testing::Eq(Timestamp::PostStream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
|
||||||
|
SetUpCalculator({"IMAGE:images"}, {});
|
||||||
|
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||||
|
std::string test_video_id = "test_video_id";
|
||||||
|
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||||
|
|
||||||
|
std::string test_image_string = "test_image_string";
|
||||||
|
|
||||||
|
int num_images = 1;
|
||||||
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
mpms::AddImageTimestamp(i, input_sequence.get());
|
||||||
|
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
|
||||||
|
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||||
|
input_sequence.get());
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag("IMAGE").packets;
|
||||||
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
const std::string& output_image = output_packets[i].Get<std::string>();
|
||||||
|
ASSERT_EQ(output_image, test_image_string);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
||||||
|
SetUpCalculator({"FLOAT_FEATURE_FDENSE_MAX:max"}, {});
|
||||||
|
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||||
|
std::string test_video_id = "test_video_id";
|
||||||
|
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||||
|
|
||||||
|
std::string test_image_string = "test_image_string";
|
||||||
|
|
||||||
|
int num_images = 1;
|
||||||
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
mpms::AddImageTimestamp(i, input_sequence.get());
|
||||||
|
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
|
||||||
|
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||||
|
input_sequence.get());
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& fdense_max_packets =
|
||||||
|
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
||||||
|
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||||
|
const auto& fdense_max_vector =
|
||||||
|
fdense_max_packets[0].Get<std::vector<float>>();
|
||||||
|
ASSERT_THAT(fdense_max_vector, ::testing::ElementsAreArray({3.0f, 4.0f}));
|
||||||
|
ASSERT_THAT(fdense_max_packets[0].Timestamp(),
|
||||||
|
::testing::Eq(Timestamp::PostStream()));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
||||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
||||||
|
|
||||||
|
|
|
@ -904,7 +904,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
||||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||||
options.compile_options.precision_loss_allowed = 1;
|
options.compile_options.precision_loss_allowed =
|
||||||
|
allow_precision_loss_ ? 1 : 0;
|
||||||
options.compile_options.preferred_gl_object_type =
|
options.compile_options.preferred_gl_object_type =
|
||||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||||
options.compile_options.dynamic_batch_enabled = 0;
|
options.compile_options.dynamic_batch_enabled = 0;
|
||||||
|
@ -968,7 +969,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
||||||
const int kHalfSize = 2; // sizeof(half)
|
const int kHalfSize = 2; // sizeof(half)
|
||||||
// Configure and create the delegate.
|
// Configure and create the delegate.
|
||||||
TFLGpuDelegateOptions options;
|
TFLGpuDelegateOptions options;
|
||||||
options.allow_precision_loss = true;
|
options.allow_precision_loss = allow_precision_loss_;
|
||||||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive;
|
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive;
|
||||||
if (!delegate_)
|
if (!delegate_)
|
||||||
delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options),
|
delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options),
|
||||||
|
@ -1080,9 +1081,10 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create converter for GPU output.
|
// Create converter for GPU output.
|
||||||
converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
converter_from_BPHWC4_ =
|
||||||
isFloat16:true
|
[[TFLBufferConvert alloc] initWithDevice:device
|
||||||
convertToPBHWC4:false];
|
isFloat16:allow_precision_loss_
|
||||||
|
convertToPBHWC4:false];
|
||||||
if (converter_from_BPHWC4_ == nil) {
|
if (converter_from_BPHWC4_ == nil) {
|
||||||
return absl::InternalError(
|
return absl::InternalError(
|
||||||
"Error initializating output buffer converter");
|
"Error initializating output buffer converter");
|
||||||
|
|
|
@ -439,7 +439,7 @@ absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu(
|
||||||
|
|
||||||
// Run shader, upsample result.
|
// Run shader, upsample result.
|
||||||
{
|
{
|
||||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(output_texture);
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(GL_TEXTURE_2D, small_mask_texture.id());
|
glBindTexture(GL_TEXTURE_2D, small_mask_texture.id());
|
||||||
GlRender();
|
GlRender();
|
||||||
|
|
|
@ -821,6 +821,25 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "landmark_projection_calculator_test",
|
||||||
|
srcs = ["landmark_projection_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":landmark_projection_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:image_to_tensor_utils",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/deps:message_matchers",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "landmarks_smoothing_calculator_proto",
|
name = "landmarks_smoothing_calculator_proto",
|
||||||
srcs = ["landmarks_smoothing_calculator.proto"],
|
srcs = ["landmarks_smoothing_calculator.proto"],
|
||||||
|
@ -1252,3 +1271,45 @@ cc_test(
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "refine_landmarks_from_heatmap_calculator_proto",
|
||||||
|
srcs = ["refine_landmarks_from_heatmap_calculator.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "refine_landmarks_from_heatmap_calculator",
|
||||||
|
srcs = ["refine_landmarks_from_heatmap_calculator.cc"],
|
||||||
|
hdrs = ["refine_landmarks_from_heatmap_calculator.h"],
|
||||||
|
copts = select({
|
||||||
|
"//mediapipe:apple": [
|
||||||
|
"-x objective-c++",
|
||||||
|
"-fobjc-arc", # enable reference-counting
|
||||||
|
],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":refine_landmarks_from_heatmap_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:statusor",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "refine_landmarks_from_heatmap_calculator_test",
|
||||||
|
srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":refine_landmarks_from_heatmap_calculator",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -402,7 +402,7 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc,
|
||||||
|
|
||||||
// Blend overlay image in GPU shader.
|
// Blend overlay image in GPU shader.
|
||||||
{
|
{
|
||||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
gpu_helper_.BindFramebuffer(output_texture);
|
||||||
|
|
||||||
glActiveTexture(GL_TEXTURE1);
|
glActiveTexture(GL_TEXTURE1);
|
||||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||||
|
|
|
@ -54,6 +54,7 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::node_hash_map<int, std::string> label_map_;
|
absl::node_hash_map<int, std::string> label_map_;
|
||||||
|
::mediapipe::DetectionLabelIdToTextCalculatorOptions options_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
|
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
|
||||||
|
|
||||||
|
@ -68,13 +69,13 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
|
||||||
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
|
|
||||||
const auto& options =
|
options_ =
|
||||||
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
|
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
|
||||||
|
|
||||||
if (options.has_label_map_path()) {
|
if (options_.has_label_map_path()) {
|
||||||
std::string string_path;
|
std::string string_path;
|
||||||
ASSIGN_OR_RETURN(string_path,
|
ASSIGN_OR_RETURN(string_path,
|
||||||
PathToResourceAsFile(options.label_map_path()));
|
PathToResourceAsFile(options_.label_map_path()));
|
||||||
std::string label_map_string;
|
std::string label_map_string;
|
||||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
||||||
|
|
||||||
|
@ -85,8 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
||||||
label_map_[i++] = line;
|
label_map_[i++] = line;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < options.label_size(); ++i) {
|
for (int i = 0; i < options_.label_size(); ++i) {
|
||||||
label_map_[i] = options.label(i);
|
label_map_[i] = options_.label(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -106,7 +107,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove label_id field if text labels exist.
|
// Remove label_id field if text labels exist.
|
||||||
if (has_text_label) {
|
if (has_text_label && !options_.keep_label_id()) {
|
||||||
output_detection.clear_label_id();
|
output_detection.clear_label_id();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,4 +31,9 @@ message DetectionLabelIdToTextCalculatorOptions {
|
||||||
// label: "label for id 1"
|
// label: "label for id 1"
|
||||||
// ...
|
// ...
|
||||||
repeated string label = 2;
|
repeated string label = 2;
|
||||||
|
|
||||||
|
// By default, the `label_id` field from the input is stripped if a text label
|
||||||
|
// could be found. By setting this field to true, it is always copied to the
|
||||||
|
// output detections.
|
||||||
|
optional bool keep_label_id = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,7 +120,11 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
labels.resize(classifications.classification_size());
|
labels.resize(classifications.classification_size());
|
||||||
scores.resize(classifications.classification_size());
|
scores.resize(classifications.classification_size());
|
||||||
for (int i = 0; i < classifications.classification_size(); ++i) {
|
for (int i = 0; i < classifications.classification_size(); ++i) {
|
||||||
labels[i] = classifications.classification(i).label();
|
if (options_.use_display_name()) {
|
||||||
|
labels[i] = classifications.classification(i).display_name();
|
||||||
|
} else {
|
||||||
|
labels[i] = classifications.classification(i).label();
|
||||||
|
}
|
||||||
scores[i] = classifications.classification(i).score();
|
scores[i] = classifications.classification(i).score();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -59,4 +59,7 @@ message LabelsToRenderDataCalculatorOptions {
|
||||||
BOTTOM_LEFT = 1;
|
BOTTOM_LEFT = 1;
|
||||||
}
|
}
|
||||||
optional Location location = 6 [default = TOP_LEFT];
|
optional Location location = 6 [default = TOP_LEFT];
|
||||||
|
|
||||||
|
// Uses Classification.display_name field instead of Classification.label.
|
||||||
|
optional bool use_display_name = 9 [default = false];
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/calculators/util/landmark_projection_calculator.pb.h"
|
#include "mediapipe/calculators/util/landmark_projection_calculator.pb.h"
|
||||||
|
@ -27,20 +28,32 @@ namespace {
|
||||||
|
|
||||||
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
|
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
|
||||||
constexpr char kRectTag[] = "NORM_RECT";
|
constexpr char kRectTag[] = "NORM_RECT";
|
||||||
|
constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Projects normalized landmarks in a rectangle to its original coordinates. The
|
// Projects normalized landmarks to its original coordinates.
|
||||||
// rectangle must also be in normalized coordinates.
|
|
||||||
// Input:
|
// Input:
|
||||||
// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks
|
// NORM_LANDMARKS - NormalizedLandmarkList
|
||||||
// in a normalized rectangle.
|
// Represents landmarks in a normalized rectangle if NORM_RECT is specified
|
||||||
// NORM_RECT: An NormalizedRect representing a normalized rectangle in image
|
// or landmarks that should be projected using PROJECTION_MATRIX if
|
||||||
// coordinates.
|
// specified. (Prefer using PROJECTION_MATRIX as it eliminates need of
|
||||||
|
// letterbox removal step.)
|
||||||
|
// NORM_RECT - NormalizedRect
|
||||||
|
// Represents a normalized rectangle in image coordinates and results in
|
||||||
|
// landmarks with their locations adjusted to the image.
|
||||||
|
// PROJECTION_MATRIX - std::array<float, 16>
|
||||||
|
// A 4x4 row-major-order matrix that maps landmarks' locations from one
|
||||||
|
// coordinate system to another. In this case from the coordinate system of
|
||||||
|
// the normalized region of interest to the coordinate system of the image.
|
||||||
|
//
|
||||||
|
// Note: either NORM_RECT or PROJECTION_MATRIX has to be specified.
|
||||||
|
// Note: landmark's Z is projected in a custom way - it's scaled by width of
|
||||||
|
// the normalized region of interest used during landmarks detection.
|
||||||
//
|
//
|
||||||
// Output:
|
// Output:
|
||||||
// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks
|
// NORM_LANDMARKS - NormalizedLandmarkList
|
||||||
// with their locations adjusted to the image.
|
// Landmarks with their locations adjusted according to the inputs.
|
||||||
//
|
//
|
||||||
// Usage example:
|
// Usage example:
|
||||||
// node {
|
// node {
|
||||||
|
@ -58,12 +71,27 @@ constexpr char kRectTag[] = "NORM_RECT";
|
||||||
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
|
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
|
||||||
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
|
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
|
||||||
// }
|
// }
|
||||||
|
//
|
||||||
|
// node {
|
||||||
|
// calculator: "LandmarkProjectionCalculator"
|
||||||
|
// input_stream: "NORM_LANDMARKS:landmarks"
|
||||||
|
// input_stream: "PROECTION_MATRIX:matrix"
|
||||||
|
// output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// node {
|
||||||
|
// calculator: "LandmarkProjectionCalculator"
|
||||||
|
// input_stream: "NORM_LANDMARKS:0:landmarks_0"
|
||||||
|
// input_stream: "NORM_LANDMARKS:1:landmarks_1"
|
||||||
|
// input_stream: "PROECTION_MATRIX:matrix"
|
||||||
|
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
|
||||||
|
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
|
||||||
|
// }
|
||||||
class LandmarkProjectionCalculator : public CalculatorBase {
|
class LandmarkProjectionCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) &&
|
RET_CHECK(cc->Inputs().HasTag(kLandmarksTag))
|
||||||
cc->Inputs().HasTag(kRectTag))
|
<< "Missing NORM_LANDMARKS input.";
|
||||||
<< "Missing one or more input streams.";
|
|
||||||
|
|
||||||
RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag),
|
RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag),
|
||||||
cc->Outputs().NumEntries(kLandmarksTag))
|
cc->Outputs().NumEntries(kLandmarksTag))
|
||||||
|
@ -73,7 +101,14 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
||||||
id != cc->Inputs().EndId(kLandmarksTag); ++id) {
|
id != cc->Inputs().EndId(kLandmarksTag); ++id) {
|
||||||
cc->Inputs().Get(id).Set<NormalizedLandmarkList>();
|
cc->Inputs().Get(id).Set<NormalizedLandmarkList>();
|
||||||
}
|
}
|
||||||
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
|
RET_CHECK(cc->Inputs().HasTag(kRectTag) ^
|
||||||
|
cc->Inputs().HasTag(kProjectionMatrix))
|
||||||
|
<< "Either NORM_RECT or PROJECTION_MATRIX must be specified.";
|
||||||
|
if (cc->Inputs().HasTag(kRectTag)) {
|
||||||
|
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
|
||||||
|
} else {
|
||||||
|
cc->Inputs().Tag(kProjectionMatrix).Set<std::array<float, 16>>();
|
||||||
|
}
|
||||||
|
|
||||||
for (CollectionItemId id = cc->Outputs().BeginId(kLandmarksTag);
|
for (CollectionItemId id = cc->Outputs().BeginId(kLandmarksTag);
|
||||||
id != cc->Outputs().EndId(kLandmarksTag); ++id) {
|
id != cc->Outputs().EndId(kLandmarksTag); ++id) {
|
||||||
|
@ -89,31 +124,50 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ProjectXY(const NormalizedLandmark& lm,
|
||||||
|
const std::array<float, 16>& matrix,
|
||||||
|
NormalizedLandmark* out) {
|
||||||
|
out->set_x(lm.x() * matrix[0] + lm.y() * matrix[1] + lm.z() * matrix[2] +
|
||||||
|
matrix[3]);
|
||||||
|
out->set_y(lm.x() * matrix[4] + lm.y() * matrix[5] + lm.z() * matrix[6] +
|
||||||
|
matrix[7]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Landmark's Z scale is equal to a relative (to image) width of region of
|
||||||
|
* interest used during detection. To calculate based on matrix:
|
||||||
|
* 1. Project (0,0) --- (1,0) segment using matrix.
|
||||||
|
* 2. Calculate length of the projected segment.
|
||||||
|
*/
|
||||||
|
static float CalculateZScale(const std::array<float, 16>& matrix) {
|
||||||
|
NormalizedLandmark a;
|
||||||
|
a.set_x(0.0f);
|
||||||
|
a.set_y(0.0f);
|
||||||
|
NormalizedLandmark b;
|
||||||
|
b.set_x(1.0f);
|
||||||
|
b.set_y(0.0f);
|
||||||
|
NormalizedLandmark a_projected;
|
||||||
|
ProjectXY(a, matrix, &a_projected);
|
||||||
|
NormalizedLandmark b_projected;
|
||||||
|
ProjectXY(b, matrix, &b_projected);
|
||||||
|
return std::sqrt(std::pow(b_projected.x() - a_projected.x(), 2) +
|
||||||
|
std::pow(b_projected.y() - a_projected.y(), 2));
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
if (cc->Inputs().Tag(kRectTag).IsEmpty()) {
|
std::function<void(const NormalizedLandmark&, NormalizedLandmark*)>
|
||||||
return absl::OkStatus();
|
project_fn;
|
||||||
}
|
if (cc->Inputs().HasTag(kRectTag)) {
|
||||||
const auto& input_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
|
if (cc->Inputs().Tag(kRectTag).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
const auto& options =
|
|
||||||
cc->Options<::mediapipe::LandmarkProjectionCalculatorOptions>();
|
|
||||||
|
|
||||||
CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag);
|
|
||||||
CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag);
|
|
||||||
// Number of inputs and outpus is the same according to the contract.
|
|
||||||
for (; input_id != cc->Inputs().EndId(kLandmarksTag);
|
|
||||||
++input_id, ++output_id) {
|
|
||||||
const auto& input_packet = cc->Inputs().Get(input_id);
|
|
||||||
if (input_packet.IsEmpty()) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
const auto& input_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
|
||||||
const auto& input_landmarks = input_packet.Get<NormalizedLandmarkList>();
|
const auto& options =
|
||||||
NormalizedLandmarkList output_landmarks;
|
cc->Options<mediapipe::LandmarkProjectionCalculatorOptions>();
|
||||||
for (int i = 0; i < input_landmarks.landmark_size(); ++i) {
|
project_fn = [&input_rect, &options](const NormalizedLandmark& landmark,
|
||||||
const NormalizedLandmark& landmark = input_landmarks.landmark(i);
|
NormalizedLandmark* new_landmark) {
|
||||||
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
|
// TODO: fix projection or deprecate (current projection
|
||||||
|
// calculations are incorrect for general case).
|
||||||
const float x = landmark.x() - 0.5f;
|
const float x = landmark.x() - 0.5f;
|
||||||
const float y = landmark.y() - 0.5f;
|
const float y = landmark.y() - 0.5f;
|
||||||
const float angle =
|
const float angle =
|
||||||
|
@ -130,10 +184,44 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
||||||
new_landmark->set_x(new_x);
|
new_landmark->set_x(new_x);
|
||||||
new_landmark->set_y(new_y);
|
new_landmark->set_y(new_y);
|
||||||
new_landmark->set_z(new_z);
|
new_landmark->set_z(new_z);
|
||||||
|
};
|
||||||
|
} else if (cc->Inputs().HasTag(kProjectionMatrix)) {
|
||||||
|
if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
const auto& project_mat =
|
||||||
|
cc->Inputs().Tag(kProjectionMatrix).Get<std::array<float, 16>>();
|
||||||
|
const float z_scale = CalculateZScale(project_mat);
|
||||||
|
project_fn = [&project_mat, z_scale](const NormalizedLandmark& lm,
|
||||||
|
NormalizedLandmark* new_landmark) {
|
||||||
|
*new_landmark = lm;
|
||||||
|
ProjectXY(lm, project_mat, new_landmark);
|
||||||
|
new_landmark->set_z(z_scale * lm.z());
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return absl::InternalError("Either rect or matrix must be specified.");
|
||||||
|
}
|
||||||
|
|
||||||
|
CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag);
|
||||||
|
CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag);
|
||||||
|
// Number of inputs and outpus is the same according to the contract.
|
||||||
|
for (; input_id != cc->Inputs().EndId(kLandmarksTag);
|
||||||
|
++input_id, ++output_id) {
|
||||||
|
const auto& input_packet = cc->Inputs().Get(input_id);
|
||||||
|
if (input_packet.IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& input_landmarks = input_packet.Get<NormalizedLandmarkList>();
|
||||||
|
NormalizedLandmarkList output_landmarks;
|
||||||
|
for (int i = 0; i < input_landmarks.landmark_size(); ++i) {
|
||||||
|
const NormalizedLandmark& landmark = input_landmarks.landmark(i);
|
||||||
|
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
|
||||||
|
project_fn(landmark, new_landmark);
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Outputs().Get(output_id).AddPacket(
|
cc->Outputs().Get(output_id).AddPacket(
|
||||||
MakePacket<NormalizedLandmarkList>(output_landmarks)
|
MakePacket<NormalizedLandmarkList>(std::move(output_landmarks))
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -0,0 +1,240 @@
|
||||||
|
#include <array>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/deps/message_matchers.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||||
|
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
|
||||||
|
mediapipe::CalculatorRunner runner(
|
||||||
|
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "LandmarkProjectionCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:landmarks"
|
||||||
|
input_stream: "NORM_RECT:rect"
|
||||||
|
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||||
|
)pb"));
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("NORM_LANDMARKS")
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||||
|
.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("NORM_RECT")
|
||||||
|
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
|
||||||
|
.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
|
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||||
|
RET_CHECK_EQ(output_packets.size(), 1);
|
||||||
|
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithDefaultRect) {
|
||||||
|
mediapipe::NormalizedLandmarkList landmarks =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
)pb");
|
||||||
|
mediapipe::NormalizedRect rect =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedRect>(
|
||||||
|
R"pb(
|
||||||
|
x_center: 0.5,
|
||||||
|
y_center: 0.5,
|
||||||
|
width: 1.0,
|
||||||
|
height: 1.0,
|
||||||
|
rotation: 0.0
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto status_or_result = RunCalculator(std::move(landmarks), std::move(rect));
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
status_or_result.value(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
mediapipe::NormalizedRect GetCroppedRect() {
|
||||||
|
return ParseTextProtoOrDie<mediapipe::NormalizedRect>(
|
||||||
|
R"pb(
|
||||||
|
x_center: 0.5, y_center: 0.5, width: 0.5, height: 2, rotation: 0.0
|
||||||
|
)pb");
|
||||||
|
}
|
||||||
|
|
||||||
|
mediapipe::NormalizedLandmarkList GetCroppedRectTestInput() {
|
||||||
|
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 1.0, y: 1.0, z: -0.5 }
|
||||||
|
)pb");
|
||||||
|
}
|
||||||
|
|
||||||
|
mediapipe::NormalizedLandmarkList GetCroppedRectTestExpectedResult() {
|
||||||
|
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 0.75, y: 1.5, z: -0.25 }
|
||||||
|
)pb");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRect) {
|
||||||
|
auto status_or_result =
|
||||||
|
RunCalculator(GetCroppedRectTestInput(), GetCroppedRect());
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(status_or_result.value(),
|
||||||
|
EqualsProto(GetCroppedRectTestExpectedResult()));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||||
|
mediapipe::NormalizedLandmarkList input, std::array<float, 16> matrix) {
|
||||||
|
mediapipe::CalculatorRunner runner(
|
||||||
|
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "LandmarkProjectionCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:landmarks"
|
||||||
|
input_stream: "PROJECTION_MATRIX:matrix"
|
||||||
|
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||||
|
)pb"));
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("NORM_LANDMARKS")
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||||
|
.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("PROJECTION_MATRIX")
|
||||||
|
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
|
||||||
|
.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
|
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||||
|
RET_CHECK_EQ(output_packets.size(), 1);
|
||||||
|
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithIdentityMatrix) {
|
||||||
|
mediapipe::NormalizedLandmarkList landmarks =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
)pb");
|
||||||
|
// clang-format off
|
||||||
|
std::array<float, 16> matrix = {
|
||||||
|
1.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
0.0f, 1.0f, 0.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 1.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 0.0f, 1.0f,
|
||||||
|
};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto status_or_result =
|
||||||
|
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
status_or_result.value(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRectMatrix) {
|
||||||
|
constexpr int kRectWidth = 1280;
|
||||||
|
constexpr int kRectHeight = 720;
|
||||||
|
auto roi = GetRoi(kRectWidth, kRectHeight, GetCroppedRect());
|
||||||
|
std::array<float, 16> matrix;
|
||||||
|
GetRotatedSubRectToRectTransformMatrix(roi, kRectWidth, kRectHeight,
|
||||||
|
/*flip_horizontaly=*/false, &matrix);
|
||||||
|
auto status_or_result = RunCalculator(GetCroppedRectTestInput(), matrix);
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(status_or_result.value(),
|
||||||
|
EqualsProto(GetCroppedRectTestExpectedResult()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithScaleMatrix) {
|
||||||
|
mediapipe::NormalizedLandmarkList landmarks =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
landmark { x: 5, y: 6, z: 7 }
|
||||||
|
)pb");
|
||||||
|
// clang-format off
|
||||||
|
std::array<float, 16> matrix = {
|
||||||
|
10.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
0.0f, 100.0f, 0.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 1.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 0.0f, 1.0f,
|
||||||
|
};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto status_or_result =
|
||||||
|
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
status_or_result.value(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 100, y: 2000, z: -5 }
|
||||||
|
landmark { x: 50, y: 600, z: 70 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithTranslateMatrix) {
|
||||||
|
mediapipe::NormalizedLandmarkList landmarks =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 10, y: 20, z: -0.5 }
|
||||||
|
)pb");
|
||||||
|
// clang-format off
|
||||||
|
std::array<float, 16> matrix = {
|
||||||
|
1.0f, 0.0f, 0.0f, 1.0f,
|
||||||
|
0.0f, 1.0f, 0.0f, 2.0f,
|
||||||
|
0.0f, 0.0f, 1.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 0.0f, 1.0f,
|
||||||
|
};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto status_or_result =
|
||||||
|
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
status_or_result.value(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 11, y: 22, z: -0.5 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarkProjectionCalculatorTest, ProjectingWithRotationMatrix) {
|
||||||
|
mediapipe::NormalizedLandmarkList landmarks =
|
||||||
|
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 4, y: 0, z: -0.5 }
|
||||||
|
)pb");
|
||||||
|
// clang-format off
|
||||||
|
// 90 degrees rotation matrix
|
||||||
|
std::array<float, 16> matrix = {
|
||||||
|
0.0f, -1.0f, 0.0f, 0.0f,
|
||||||
|
1.0f, 0.0f, 0.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 1.0f, 0.0f,
|
||||||
|
0.0f, 0.0f, 0.0f, 1.0f,
|
||||||
|
};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto status_or_result =
|
||||||
|
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||||
|
MP_ASSERT_OK(status_or_result);
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
status_or_result.value(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||||
|
landmark { x: 0, y: 4, z: -0.5 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -1,3 +1,17 @@
|
||||||
|
// Copyright 2020 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h"
|
#include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
// Copyright 2020 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe;
|
package mediapipe;
|
||||||
|
|
|
@ -0,0 +1,166 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); }
|
||||||
|
|
||||||
|
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
|
||||||
|
const std::vector<int>& dims) {
|
||||||
|
if (dims.size() == 3) {
|
||||||
|
return std::make_tuple(dims[0], dims[1], dims[2]);
|
||||||
|
} else if (dims.size() == 4) {
|
||||||
|
// BHWC format check B == 1
|
||||||
|
RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
|
||||||
|
return std::make_tuple(dims[1], dims[2], dims[3]);
|
||||||
|
} else {
|
||||||
|
RET_CHECK(false) << "Invalid shape size for heatmap tensor" << dims.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
// Refines landmarks using correspond heatmap area.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// NORM_LANDMARKS - Required. Input normalized landmarks to update.
|
||||||
|
// TENSORS - Required. Vector of input tensors. 0th element should be heatmap.
|
||||||
|
// The rest is unused.
|
||||||
|
// Output:
|
||||||
|
// NORM_LANDMARKS - Required. Updated normalized landmarks.
|
||||||
|
class RefineLandmarksFromHeatmapCalculatorImpl
|
||||||
|
: public NodeImpl<RefineLandmarksFromHeatmapCalculator,
|
||||||
|
RefineLandmarksFromHeatmapCalculatorImpl> {
|
||||||
|
public:
|
||||||
|
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
// Make sure we bypass landmarks if there is no detection.
|
||||||
|
if (kInLandmarks(cc).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
// If for some reason heatmap is missing, just return original landmarks.
|
||||||
|
if (kInTensors(cc).IsEmpty()) {
|
||||||
|
kOutLandmarks(cc).Send(*kInLandmarks(cc));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check basic prerequisites.
|
||||||
|
const auto& input_tensors = *kInTensors(cc);
|
||||||
|
RET_CHECK(!input_tensors.empty()) << "Empty input tensors list. First "
|
||||||
|
"element is expeced to be a heatmap";
|
||||||
|
|
||||||
|
const auto& hm_tensor = input_tensors[0];
|
||||||
|
const auto& in_lms = *kInLandmarks(cc);
|
||||||
|
auto hm_view = hm_tensor.GetCpuReadView();
|
||||||
|
auto hm_raw = hm_view.buffer<float>();
|
||||||
|
const auto& options =
|
||||||
|
cc->Options<mediapipe::RefineLandmarksFromHeatmapCalculatorOptions>();
|
||||||
|
|
||||||
|
ASSIGN_OR_RETURN(auto out_lms, RefineLandmarksFromHeatMap(
|
||||||
|
in_lms, hm_raw, hm_tensor.shape().dims,
|
||||||
|
options.kernel_size(),
|
||||||
|
options.min_confidence_to_refine()));
|
||||||
|
|
||||||
|
kOutLandmarks(cc).Send(std::move(out_lms));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
|
||||||
|
// Runs actual refinement
|
||||||
|
// High level algorithm:
|
||||||
|
//
|
||||||
|
// Heatmap is accepted as tensor in HWC layout where i-th channel is a heatmap
|
||||||
|
// for the i-th landmark.
|
||||||
|
//
|
||||||
|
// For each landmark we replace original value with a value calculated from the
|
||||||
|
// area in heatmap close to original landmark position (in particular are
|
||||||
|
// covered with kernel of size options.kernel_size). To calculate new coordinate
|
||||||
|
// from heatmap we calculate an weighted average inside the kernel. We update
|
||||||
|
// the landmark iff heatmap is confident in it's prediction i.e. max(heatmap) in
|
||||||
|
// kernel is at least options.min_confidence_to_refine big.
|
||||||
|
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
|
||||||
|
const mediapipe::NormalizedLandmarkList& in_lms,
|
||||||
|
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
|
||||||
|
int kernel_size, float min_confidence_to_refine) {
|
||||||
|
ASSIGN_OR_RETURN(auto hm_dims, GetHwcFromDims(heatmap_dims));
|
||||||
|
auto [hm_height, hm_width, hm_channels] = hm_dims;
|
||||||
|
|
||||||
|
RET_CHECK_EQ(in_lms.landmark_size(), hm_channels)
|
||||||
|
<< "Expected heatmap to have number of layers == to number of "
|
||||||
|
"landmarks";
|
||||||
|
|
||||||
|
int hm_row_size = hm_width * hm_channels;
|
||||||
|
int hm_pixel_size = hm_channels;
|
||||||
|
|
||||||
|
mediapipe::NormalizedLandmarkList out_lms = in_lms;
|
||||||
|
for (int lm_index = 0; lm_index < out_lms.landmark_size(); ++lm_index) {
|
||||||
|
int center_col = out_lms.landmark(lm_index).x() * hm_width;
|
||||||
|
int center_row = out_lms.landmark(lm_index).y() * hm_height;
|
||||||
|
// Point is outside of the image let's keep it intact.
|
||||||
|
if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
|
||||||
|
center_col >= hm_height) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int offset = (kernel_size - 1) / 2;
|
||||||
|
// Calculate area to iterate over. Note that we decrease the kernel on
|
||||||
|
// the edges of the heatmap. Equivalent to zero border.
|
||||||
|
int begin_col = std::max(0, center_col - offset);
|
||||||
|
int end_col = std::min(hm_width, center_col + offset + 1);
|
||||||
|
int begin_row = std::max(0, center_row - offset);
|
||||||
|
int end_row = std::min(hm_height, center_row + offset + 1);
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
|
float weighted_col = 0;
|
||||||
|
float weighted_row = 0;
|
||||||
|
float max_value = 0;
|
||||||
|
|
||||||
|
// Main loop. Go over kernel and calculate weighted sum of coordinates,
|
||||||
|
// sum of weights and max weights.
|
||||||
|
for (int row = begin_row; row < end_row; ++row) {
|
||||||
|
for (int col = begin_col; col < end_col; ++col) {
|
||||||
|
// We expect memory to be in HWC layout without padding.
|
||||||
|
int idx = hm_row_size * row + hm_pixel_size * col + lm_index;
|
||||||
|
// Right now we hardcode sigmoid activation as it will be wasteful to
|
||||||
|
// calculate sigmoid for each value of heatmap in the model itself. If
|
||||||
|
// we ever have other activations it should be trivial to expand via
|
||||||
|
// options.
|
||||||
|
float confidence = Sigmoid(heatmap_raw_data[idx]);
|
||||||
|
sum += confidence;
|
||||||
|
max_value = std::max(max_value, confidence);
|
||||||
|
weighted_col += col * confidence;
|
||||||
|
weighted_row += row * confidence;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (max_value >= min_confidence_to_refine && sum > 0) {
|
||||||
|
out_lms.mutable_landmark(lm_index)->set_x(weighted_col / hm_width / sum);
|
||||||
|
out_lms.mutable_landmark(lm_index)->set_y(weighted_row / hm_height / sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out_lms;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
class RefineLandmarksFromHeatmapCalculator : public NodeIntf {
|
||||||
|
public:
|
||||||
|
static constexpr Input<mediapipe::NormalizedLandmarkList> kInLandmarks{
|
||||||
|
"NORM_LANDMARKS"};
|
||||||
|
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||||
|
static constexpr Output<mediapipe::NormalizedLandmarkList> kOutLandmarks{
|
||||||
|
"NORM_LANDMARKS"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_INTERFACE(RefineLandmarksFromHeatmapCalculator, kInLandmarks,
|
||||||
|
kInTensors, kOutLandmarks);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
|
||||||
|
// Exposed for testing.
|
||||||
|
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
|
||||||
|
const mediapipe::NormalizedLandmarkList& in_lms,
|
||||||
|
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
|
||||||
|
int kernel_size, float min_confidence_to_refine);
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
|
@ -0,0 +1,27 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message RefineLandmarksFromHeatmapCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional RefineLandmarksFromHeatmapCalculatorOptions ext = 362281653;
|
||||||
|
}
|
||||||
|
optional int32 kernel_size = 1 [default = 9];
|
||||||
|
optional float min_confidence_to_refine = 2 [default = 0.5];
|
||||||
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
mediapipe::NormalizedLandmarkList vec_to_lms(
|
||||||
|
const std::vector<std::pair<float, float>>& inp) {
|
||||||
|
mediapipe::NormalizedLandmarkList ret;
|
||||||
|
for (const auto& it : inp) {
|
||||||
|
auto new_lm = ret.add_landmark();
|
||||||
|
new_lm->set_x(it.first);
|
||||||
|
new_lm->set_y(it.second);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<float, float>> lms_to_vec(
|
||||||
|
const mediapipe::NormalizedLandmarkList& lst) {
|
||||||
|
std::vector<std::pair<float, float>> ret;
|
||||||
|
for (const auto& lm : lst.landmark()) {
|
||||||
|
ret.push_back({lm.x(), lm.y()});
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> CHW_to_HWC(std::vector<float> inp, int height, int width,
|
||||||
|
int depth) {
|
||||||
|
std::vector<float> ret(inp.size());
|
||||||
|
const float* inp_ptr = inp.data();
|
||||||
|
for (int c = 0; c < depth; ++c) {
|
||||||
|
for (int row = 0; row < height; ++row) {
|
||||||
|
for (int col = 0; col < width; ++col) {
|
||||||
|
int dest_idx = width * depth * row + depth * col + c;
|
||||||
|
ret[dest_idx] = *inp_ptr;
|
||||||
|
++inp_ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
using testing::ElementsAre;
|
||||||
|
using testing::FloatEq;
|
||||||
|
using testing::Pair;
|
||||||
|
|
||||||
|
TEST(RefineLandmarksFromHeatmapTest, Smoke) {
|
||||||
|
float z = -10000000000000000;
|
||||||
|
// clang-format off
|
||||||
|
std::vector<float> hm = {
|
||||||
|
z, z, z,
|
||||||
|
1, z, z,
|
||||||
|
z, z, z};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto ret_or_error = RefineLandmarksFromHeatMap(vec_to_lms({{0.5, 0.5}}),
|
||||||
|
hm.data(), {3, 3, 1}, 3, 0.1);
|
||||||
|
MP_EXPECT_OK(ret_or_error);
|
||||||
|
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||||
|
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefineLandmarksFromHeatmapTest, MultiLayer) {
|
||||||
|
float z = -10000000000000000;
|
||||||
|
// clang-format off
|
||||||
|
std::vector<float> hm = CHW_to_HWC({
|
||||||
|
z, z, z,
|
||||||
|
1, z, z,
|
||||||
|
z, z, z,
|
||||||
|
z, z, z,
|
||||||
|
1, z, z,
|
||||||
|
z, z, z,
|
||||||
|
z, z, z,
|
||||||
|
1, z, z,
|
||||||
|
z, z, z}, 3, 3, 3);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||||
|
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
|
||||||
|
0.1);
|
||||||
|
MP_EXPECT_OK(ret_or_error);
|
||||||
|
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||||
|
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||||
|
Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||||
|
Pair(FloatEq(0), FloatEq(1 / 3.))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefineLandmarksFromHeatmapTest, KeepIfNotSure) {
|
||||||
|
float z = -10000000000000000;
|
||||||
|
// clang-format off
|
||||||
|
std::vector<float> hm = CHW_to_HWC({
|
||||||
|
z, z, z,
|
||||||
|
0, z, z,
|
||||||
|
z, z, z,
|
||||||
|
z, z, z,
|
||||||
|
0, z, z,
|
||||||
|
z, z, z,
|
||||||
|
z, z, z,
|
||||||
|
0, z, z,
|
||||||
|
z, z, z}, 3, 3, 3);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||||
|
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
|
||||||
|
0.6);
|
||||||
|
MP_EXPECT_OK(ret_or_error);
|
||||||
|
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||||
|
ElementsAre(Pair(FloatEq(0.5), FloatEq(0.5)),
|
||||||
|
Pair(FloatEq(0.5), FloatEq(0.5)),
|
||||||
|
Pair(FloatEq(0.5), FloatEq(0.5))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefineLandmarksFromHeatmapTest, Border) {
|
||||||
|
float z = -10000000000000000;
|
||||||
|
// clang-format off
|
||||||
|
std::vector<float> hm = CHW_to_HWC({
|
||||||
|
z, z, z,
|
||||||
|
0, z, 0,
|
||||||
|
z, z, z,
|
||||||
|
|
||||||
|
z, z, z,
|
||||||
|
0, z, 0,
|
||||||
|
z, z, 0}, 3, 3, 2);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||||
|
vec_to_lms({{0.0, 0.0}, {0.9, 0.9}}), hm.data(), {3, 3, 2}, 3, 0.1);
|
||||||
|
MP_EXPECT_OK(ret_or_error);
|
||||||
|
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||||
|
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||||
|
Pair(FloatEq(2 / 3.), FloatEq(1 / 6. + 2 / 6.))));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -101,7 +101,7 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
||||||
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<float>();
|
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,8 +43,7 @@ android_binary(
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
|
||||||
],
|
],
|
||||||
assets_dir = "",
|
assets_dir = "",
|
||||||
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
||||||
|
|
|
@ -37,7 +37,7 @@ android_binary(
|
||||||
srcs = glob(["*.java"]),
|
srcs = glob(["*.java"]),
|
||||||
assets = [
|
assets = [
|
||||||
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
],
|
],
|
||||||
assets_dir = "",
|
assets_dir = "",
|
||||||
|
|
|
@ -1,64 +0,0 @@
|
||||||
# Copyright 2019 The MediaPipe Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:private"])
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "libmediapipe_jni.so",
|
|
||||||
linkshared = 1,
|
|
||||||
linkstatic = 1,
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "mediapipe_jni_lib",
|
|
||||||
srcs = [":libmediapipe_jni.so"],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
android_binary(
|
|
||||||
name = "upperbodyposetrackinggpu",
|
|
||||||
srcs = glob(["*.java"]),
|
|
||||||
assets = [
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
|
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
|
||||||
],
|
|
||||||
assets_dir = "",
|
|
||||||
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
|
||||||
manifest_values = {
|
|
||||||
"applicationId": "com.google.mediapipe.apps.upperbodyposetrackinggpu",
|
|
||||||
"appName": "Upper Body Pose Tracking",
|
|
||||||
"mainActivity": ".MainActivity",
|
|
||||||
"cameraFacingFront": "False",
|
|
||||||
"binaryGraphName": "upper_body_pose_tracking_gpu.binarypb",
|
|
||||||
"inputVideoStreamName": "input_video",
|
|
||||||
"outputVideoStreamName": "output_video",
|
|
||||||
"flipFramesVertically": "True",
|
|
||||||
"converterNumBuffers": "2",
|
|
||||||
},
|
|
||||||
multidex = "native",
|
|
||||||
deps = [
|
|
||||||
":mediapipe_jni_lib",
|
|
||||||
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib",
|
|
||||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
|
||||||
"@com_google_protobuf//:protobuf_javalite",
|
|
||||||
],
|
|
||||||
)
|
|
|
@ -1,75 +0,0 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package com.google.mediapipe.apps.upperbodyposetrackinggpu;
|
|
||||||
|
|
||||||
import android.os.Bundle;
|
|
||||||
import android.util.Log;
|
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
|
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
|
|
||||||
import com.google.mediapipe.framework.PacketGetter;
|
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
|
||||||
|
|
||||||
/** Main activity of MediaPipe upper-body pose tracking app. */
|
|
||||||
public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
|
|
||||||
private static final String TAG = "MainActivity";
|
|
||||||
|
|
||||||
private static final String OUTPUT_LANDMARKS_STREAM_NAME = "pose_landmarks";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void onCreate(Bundle savedInstanceState) {
|
|
||||||
super.onCreate(savedInstanceState);
|
|
||||||
|
|
||||||
// To show verbose logging, run:
|
|
||||||
// adb shell setprop log.tag.MainActivity VERBOSE
|
|
||||||
if (Log.isLoggable(TAG, Log.VERBOSE)) {
|
|
||||||
processor.addPacketCallback(
|
|
||||||
OUTPUT_LANDMARKS_STREAM_NAME,
|
|
||||||
(packet) -> {
|
|
||||||
Log.v(TAG, "Received pose landmarks packet.");
|
|
||||||
try {
|
|
||||||
NormalizedLandmarkList poseLandmarks =
|
|
||||||
PacketGetter.getProto(packet, NormalizedLandmarkList.class);
|
|
||||||
Log.v(
|
|
||||||
TAG,
|
|
||||||
"[TS:"
|
|
||||||
+ packet.getTimestamp()
|
|
||||||
+ "] "
|
|
||||||
+ getPoseLandmarksDebugString(poseLandmarks));
|
|
||||||
} catch (InvalidProtocolBufferException exception) {
|
|
||||||
Log.e(TAG, "Failed to get proto.", exception);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getPoseLandmarksDebugString(NormalizedLandmarkList poseLandmarks) {
|
|
||||||
String poseLandmarkStr = "Pose landmarks: " + poseLandmarks.getLandmarkCount() + "\n";
|
|
||||||
int landmarkIndex = 0;
|
|
||||||
for (NormalizedLandmark landmark : poseLandmarks.getLandmarkList()) {
|
|
||||||
poseLandmarkStr +=
|
|
||||||
"\tLandmark ["
|
|
||||||
+ landmarkIndex
|
|
||||||
+ "]: ("
|
|
||||||
+ landmark.getX()
|
|
||||||
+ ", "
|
|
||||||
+ landmark.getY()
|
|
||||||
+ ", "
|
|
||||||
+ landmark.getZ()
|
|
||||||
+ ")\n";
|
|
||||||
++landmarkIndex;
|
|
||||||
}
|
|
||||||
return poseLandmarkStr;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -142,26 +142,13 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Maps detection label IDs to the corresponding label text ("Face"). The label
|
|
||||||
# map is provided in the label_map_path option.
|
|
||||||
node {
|
|
||||||
calculator: "DetectionLabelIdToTextCalculator"
|
|
||||||
input_stream: "filtered_detections"
|
|
||||||
output_stream: "labeled_detections"
|
|
||||||
options: {
|
|
||||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
|
||||||
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
||||||
# letterboxed image (after image transformation with the FIT scale mode) to the
|
# letterboxed image (after image transformation with the FIT scale mode) to the
|
||||||
# corresponding locations on the same image with the letterbox removed (the
|
# corresponding locations on the same image with the letterbox removed (the
|
||||||
# input image to the graph before image transformation).
|
# input image to the graph before image transformation).
|
||||||
node {
|
node {
|
||||||
calculator: "DetectionLetterboxRemovalCalculator"
|
calculator: "DetectionLetterboxRemovalCalculator"
|
||||||
input_stream: "DETECTIONS:labeled_detections"
|
input_stream: "DETECTIONS:filtered_detections"
|
||||||
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||||
output_stream: "DETECTIONS:output_detections"
|
output_stream: "DETECTIONS:output_detections"
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,10 @@ constexpr char kDetections[] = "DETECTIONS";
|
||||||
constexpr char kDetectedBorders[] = "BORDERS";
|
constexpr char kDetectedBorders[] = "BORDERS";
|
||||||
constexpr char kCropRect[] = "CROP_RECT";
|
constexpr char kCropRect[] = "CROP_RECT";
|
||||||
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
|
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
|
||||||
|
// Can be used to control whether an animated zoom should actually performed
|
||||||
|
// (configured through option us_to_first_rect). If provided, a non-zero integer
|
||||||
|
// will allow the animated zoom to be used when the first detections arrive.
|
||||||
|
constexpr char kAnimateZoom[] = "ANIMATE_ZOOM";
|
||||||
// Field-of-view (degrees) of the camera's x-axis (width).
|
// Field-of-view (degrees) of the camera's x-axis (width).
|
||||||
// TODO: Parameterize FOV based on camera specs.
|
// TODO: Parameterize FOV based on camera specs.
|
||||||
constexpr float kFieldOfView = 60;
|
constexpr float kFieldOfView = 60;
|
||||||
|
@ -76,10 +80,10 @@ class ContentZoomingCalculator : public CalculatorBase {
|
||||||
absl::Status InitializeState(int frame_width, int frame_height);
|
absl::Status InitializeState(int frame_width, int frame_height);
|
||||||
// Adjusts state to work with an updated frame size.
|
// Adjusts state to work with an updated frame size.
|
||||||
absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
|
absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
|
||||||
// Returns true if we are zooming to the initial rect.
|
// Returns true if we are animating to the first rect.
|
||||||
bool IsZoomingToInitialRect(const Timestamp& timestamp) const;
|
bool IsAnimatingToFirstRect(const Timestamp& timestamp) const;
|
||||||
// Builds the output rectangle when zooming to the initial rect.
|
// Builds the output rectangle when animating to the first rect.
|
||||||
absl::StatusOr<mediapipe::Rect> GetInitialZoomingRect(
|
absl::StatusOr<mediapipe::Rect> GetAnimationRect(
|
||||||
int frame_width, int frame_height, const Timestamp& timestamp) const;
|
int frame_width, int frame_height, const Timestamp& timestamp) const;
|
||||||
// Converts bounds to tilt offset, pan offset and height.
|
// Converts bounds to tilt offset, pan offset and height.
|
||||||
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
||||||
|
@ -97,7 +101,10 @@ class ContentZoomingCalculator : public CalculatorBase {
|
||||||
std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
|
std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
|
||||||
// Are parameters initialized.
|
// Are parameters initialized.
|
||||||
bool initialized_;
|
bool initialized_;
|
||||||
// Stores the time of the first crop rectangle.
|
// Stores the time of the first crop rectangle. This is used to control
|
||||||
|
// animating to it. Until a first crop rectangle was computed, it has
|
||||||
|
// the value Timestamp::Unset(). If animating is not requested, it receives
|
||||||
|
// the value Timestamp::Done() instead of the time.
|
||||||
Timestamp first_rect_timestamp_;
|
Timestamp first_rect_timestamp_;
|
||||||
// Stores the first crop rectangle.
|
// Stores the first crop rectangle.
|
||||||
mediapipe::NormalizedRect first_rect_;
|
mediapipe::NormalizedRect first_rect_;
|
||||||
|
@ -135,6 +142,9 @@ absl::Status ContentZoomingCalculator::GetContract(
|
||||||
if (cc->Inputs().HasTag(kDetections)) {
|
if (cc->Inputs().HasTag(kDetections)) {
|
||||||
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
|
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
|
||||||
}
|
}
|
||||||
|
if (cc->Inputs().HasTag(kAnimateZoom)) {
|
||||||
|
cc->Inputs().Tag(kAnimateZoom).Set<bool>();
|
||||||
|
}
|
||||||
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||||
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
||||||
}
|
}
|
||||||
|
@ -419,10 +429,11 @@ absl::Status ContentZoomingCalculator::UpdateForResolutionChange(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ContentZoomingCalculator::IsZoomingToInitialRect(
|
bool ContentZoomingCalculator::IsAnimatingToFirstRect(
|
||||||
const Timestamp& timestamp) const {
|
const Timestamp& timestamp) const {
|
||||||
if (options_.us_to_first_rect() == 0 ||
|
if (options_.us_to_first_rect() == 0 ||
|
||||||
first_rect_timestamp_ == Timestamp::Unset()) {
|
first_rect_timestamp_ == Timestamp::Unset() ||
|
||||||
|
first_rect_timestamp_ == Timestamp::Done()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -443,10 +454,10 @@ double easeInOutQuad(double t) {
|
||||||
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
|
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetInitialZoomingRect(
|
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetAnimationRect(
|
||||||
int frame_width, int frame_height, const Timestamp& timestamp) const {
|
int frame_width, int frame_height, const Timestamp& timestamp) const {
|
||||||
RET_CHECK(IsZoomingToInitialRect(timestamp))
|
RET_CHECK(IsAnimatingToFirstRect(timestamp))
|
||||||
<< "Must only be called if zooming to initial rect.";
|
<< "Must only be called if animating to first rect.";
|
||||||
|
|
||||||
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
|
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
|
||||||
const int64 delay = options_.us_to_first_rect_delay();
|
const int64 delay = options_.us_to_first_rect_delay();
|
||||||
|
@ -538,15 +549,20 @@ absl::Status ContentZoomingCalculator::Process(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
const bool may_start_animation = (options_.us_to_first_rect() != 0) &&
|
||||||
|
(!cc->Inputs().HasTag(kAnimateZoom) ||
|
||||||
|
cc->Inputs().Tag(kAnimateZoom).Get<bool>());
|
||||||
|
bool is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
|
||||||
|
|
||||||
int offset_y, height, offset_x;
|
int offset_y, height, offset_x;
|
||||||
if (zooming_to_initial_rect) {
|
if (!is_animating && options_.start_zoomed_out() && !may_start_animation &&
|
||||||
// If we are zooming to the first rect, ignore any new incoming detections.
|
first_rect_timestamp_ == Timestamp::Unset()) {
|
||||||
height = last_measured_height_;
|
// If we should start zoomed out and won't be doing an animation,
|
||||||
offset_x = last_measured_x_offset_;
|
// initialize the path solvers using the full frame, ignoring detections.
|
||||||
offset_y = last_measured_y_offset_;
|
height = max_frame_value_ * frame_height_;
|
||||||
} else if (only_required_found) {
|
offset_x = (target_aspect_ * height) / 2;
|
||||||
|
offset_y = frame_height_ / 2;
|
||||||
|
} else if (!is_animating && only_required_found) {
|
||||||
// Convert bounds to tilt/zoom and in pixel coordinates.
|
// Convert bounds to tilt/zoom and in pixel coordinates.
|
||||||
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||||
&offset_x, &height));
|
&offset_x, &height));
|
||||||
|
@ -555,9 +571,9 @@ absl::Status ContentZoomingCalculator::Process(
|
||||||
last_measured_height_ = height;
|
last_measured_height_ = height;
|
||||||
last_measured_x_offset_ = offset_x;
|
last_measured_x_offset_ = offset_x;
|
||||||
last_measured_y_offset_ = offset_y;
|
last_measured_y_offset_ = offset_y;
|
||||||
} else if (cc->InputTimestamp().Microseconds() -
|
} else if (!is_animating && cc->InputTimestamp().Microseconds() -
|
||||||
last_only_required_detection_ >=
|
last_only_required_detection_ >=
|
||||||
options_.us_before_zoomout()) {
|
options_.us_before_zoomout()) {
|
||||||
// No only_require detections found within salient regions packets
|
// No only_require detections found within salient regions packets
|
||||||
// arriving since us_before_zoomout duration.
|
// arriving since us_before_zoomout duration.
|
||||||
height = max_frame_value_ * frame_height_ +
|
height = max_frame_value_ * frame_height_ +
|
||||||
|
@ -566,7 +582,8 @@ absl::Status ContentZoomingCalculator::Process(
|
||||||
offset_x = (target_aspect_ * height) / 2;
|
offset_x = (target_aspect_ * height) / 2;
|
||||||
offset_y = frame_height_ / 2;
|
offset_y = frame_height_ / 2;
|
||||||
} else {
|
} else {
|
||||||
// No only detection found but using last detection due to
|
// Either animating to the first rectangle, or
|
||||||
|
// no only detection found but using last detection due to
|
||||||
// duration_before_zoomout_us setting.
|
// duration_before_zoomout_us setting.
|
||||||
height = last_measured_height_;
|
height = last_measured_height_;
|
||||||
offset_x = last_measured_x_offset_;
|
offset_x = last_measured_x_offset_;
|
||||||
|
@ -642,24 +659,28 @@ absl::Status ContentZoomingCalculator::Process(
|
||||||
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (first_rect_timestamp_ == Timestamp::Unset() &&
|
// Record the first crop rectangle
|
||||||
options_.us_to_first_rect() != 0) {
|
if (first_rect_timestamp_ == Timestamp::Unset()) {
|
||||||
first_rect_timestamp_ = cc->InputTimestamp();
|
|
||||||
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
|
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
|
||||||
first_rect_.set_width(path_height * target_aspect_ /
|
first_rect_.set_width(path_height * target_aspect_ /
|
||||||
static_cast<float>(frame_width_));
|
static_cast<float>(frame_width_));
|
||||||
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
|
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
|
||||||
first_rect_.set_height(path_height / static_cast<float>(frame_height_));
|
first_rect_.set_height(path_height / static_cast<float>(frame_height_));
|
||||||
// After setting the first rectangle, check whether we should zoom to it.
|
|
||||||
zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
// Record the time to serve as departure point for the animation.
|
||||||
|
// If we are not allowed to start the animation, set Timestamp::Done.
|
||||||
|
first_rect_timestamp_ =
|
||||||
|
may_start_animation ? cc->InputTimestamp() : Timestamp::Done();
|
||||||
|
// After setting the first rectangle, check whether we should animate to it.
|
||||||
|
is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transmit downstream to glcroppingcalculator.
|
// Transmit downstream to glcroppingcalculator.
|
||||||
if (cc->Outputs().HasTag(kCropRect)) {
|
if (cc->Outputs().HasTag(kCropRect)) {
|
||||||
std::unique_ptr<mediapipe::Rect> gpu_rect;
|
std::unique_ptr<mediapipe::Rect> gpu_rect;
|
||||||
if (zooming_to_initial_rect) {
|
if (is_animating) {
|
||||||
auto rect = GetInitialZoomingRect(frame_width, frame_height,
|
auto rect =
|
||||||
cc->InputTimestamp());
|
GetAnimationRect(frame_width, frame_height, cc->InputTimestamp());
|
||||||
MP_RETURN_IF_ERROR(rect.status());
|
MP_RETURN_IF_ERROR(rect.status());
|
||||||
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
|
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.autoflip;
|
||||||
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
|
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
// NextTag: 17
|
// NextTag: 18
|
||||||
message ContentZoomingCalculatorOptions {
|
message ContentZoomingCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ContentZoomingCalculatorOptions ext = 313091992;
|
optional ContentZoomingCalculatorOptions ext = 313091992;
|
||||||
|
@ -58,9 +58,15 @@ message ContentZoomingCalculatorOptions {
|
||||||
// Whether to keep state between frames or to compute the final crop rect.
|
// Whether to keep state between frames or to compute the final crop rect.
|
||||||
optional bool is_stateless = 14 [default = false];
|
optional bool is_stateless = 14 [default = false];
|
||||||
|
|
||||||
// Duration (in MicroSeconds) for moving to the first crop rect.
|
// If true, on the first packet start with the camera zoomed out and then zoom
|
||||||
|
// in on the subject. If false, the camera will start zoomed in on the
|
||||||
|
// subject.
|
||||||
|
optional bool start_zoomed_out = 17 [default = false];
|
||||||
|
|
||||||
|
// Duration (in MicroSeconds) for animating to the first crop rect.
|
||||||
|
// Note that if set, takes precedence over start_zoomed_out.
|
||||||
optional int64 us_to_first_rect = 15 [default = 0];
|
optional int64 us_to_first_rect = 15 [default = 0];
|
||||||
// Duration (in MicroSeconds) to delay moving to the first crop rect.
|
// Duration (in MicroSeconds) to delay animating to the first crop rect.
|
||||||
// Used only if us_to_first_rect is set and is interpreted as part of the
|
// Used only if us_to_first_rect is set and is interpreted as part of the
|
||||||
// us_to_first_rect time budget.
|
// us_to_first_rect time budget.
|
||||||
optional int64 us_to_first_rect_delay = 16 [default = 0];
|
optional int64 us_to_first_rect_delay = 16 [default = 0];
|
||||||
|
|
|
@ -127,6 +127,29 @@ const char kConfigD[] = R"(
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
const char kConfigE[] = R"(
|
||||||
|
calculator: "ContentZoomingCalculator"
|
||||||
|
input_stream: "VIDEO_SIZE:size"
|
||||||
|
input_stream: "DETECTIONS:detections"
|
||||||
|
input_stream: "ANIMATE_ZOOM:animate_zoom"
|
||||||
|
output_stream: "CROP_RECT:rect"
|
||||||
|
output_stream: "FIRST_CROP_RECT:first_rect"
|
||||||
|
options: {
|
||||||
|
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
||||||
|
max_zoom_value_deg: 0
|
||||||
|
kinematic_options_zoom {
|
||||||
|
min_motion_to_reframe: 1.2
|
||||||
|
}
|
||||||
|
kinematic_options_tilt {
|
||||||
|
min_motion_to_reframe: 1.2
|
||||||
|
}
|
||||||
|
kinematic_options_pan {
|
||||||
|
min_motion_to_reframe: 1.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||||
int top_border, int bottom_border) {
|
int top_border, int bottom_border) {
|
||||||
ASSERT_EQ(2, static_features.border().size());
|
ASSERT_EQ(2, static_features.border().size());
|
||||||
|
@ -145,9 +168,14 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||||
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AddDetectionFlags {
|
||||||
|
std::optional<bool> animated_zoom;
|
||||||
|
};
|
||||||
|
|
||||||
void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
||||||
const int width, const int height,
|
const int width, const int height,
|
||||||
CalculatorRunner* runner) {
|
CalculatorRunner* runner,
|
||||||
|
const AddDetectionFlags& flags = {}) {
|
||||||
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
||||||
if (position.width > 0 && position.height > 0) {
|
if (position.width > 0 && position.height > 0) {
|
||||||
mediapipe::Detection detection;
|
mediapipe::Detection detection;
|
||||||
|
@ -175,6 +203,14 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("VIDEO_SIZE")
|
->Tag("VIDEO_SIZE")
|
||||||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
|
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
if (flags.animated_zoom.has_value()) {
|
||||||
|
runner->MutableInputs()
|
||||||
|
->Tag("ANIMATE_ZOOM")
|
||||||
|
.packets.push_back(
|
||||||
|
mediapipe::MakePacket<bool>(flags.animated_zoom.value())
|
||||||
|
.At(Timestamp(time)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddDetection(const cv::Rect_<float>& position, const int64 time,
|
void AddDetection(const cv::Rect_<float>& position, const int64 time,
|
||||||
|
@ -703,7 +739,33 @@ TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
|
||||||
CheckCropRect(500, 500, 1000, 1000, 2,
|
CheckCropRect(500, 500, 1000, 1000, 2,
|
||||||
runner->Outputs().Tag("CROP_RECT").packets);
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->set_start_zoomed_out(true);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||||
|
runner.get());
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||||
|
runner.get());
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||||
|
runner.get());
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||||
|
runner.get());
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 880, 880, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 760, 760, 2,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 655, 655, 3,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, AnimateToFirstRect) {
|
||||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
auto* options = config.mutable_options()->MutableExtension(
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
ContentZoomingCalculatorOptions::ext);
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
@ -733,6 +795,65 @@ TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
||||||
runner->Outputs().Tag("CROP_RECT").packets);
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, CanControlAnimation) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->set_start_zoomed_out(true);
|
||||||
|
options->set_us_to_first_rect(1000000);
|
||||||
|
options->set_us_to_first_rect_delay(500000);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
// Request the animation for the first frame.
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = true});
|
||||||
|
// We now stop requesting animated zoom and expect the already started
|
||||||
|
// animation run to completion. This tests that the zoom in continues in the
|
||||||
|
// call when it was started in the Meet greenroom.
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 1000, 1000, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 470, 470, 2,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 222, 222, 3,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 222, 222, 4,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContentZoomingCalculatorTest, DoesNotAnimateIfDisabledViaInput) {
|
||||||
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
|
||||||
|
auto* options = config.mutable_options()->MutableExtension(
|
||||||
|
ContentZoomingCalculatorOptions::ext);
|
||||||
|
options->set_start_zoomed_out(true);
|
||||||
|
options->set_us_to_first_rect(1000000);
|
||||||
|
options->set_us_to_first_rect_delay(500000);
|
||||||
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
// Disable the animation already for the first frame.
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||||
|
runner.get(), {.animated_zoom = false});
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 880, 880, 1,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
CheckCropRect(500, 500, 760, 760, 2,
|
||||||
|
runner->Outputs().Tag("CROP_RECT").packets);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
|
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
|
||||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||||
|
|
|
@ -47,4 +47,10 @@ message FaceBoxAdjusterCalculatorOptions {
|
||||||
// and height respectively.
|
// and height respectively.
|
||||||
optional float ipd_face_box_width_ratio = 6 [default = 0.5566];
|
optional float ipd_face_box_width_ratio = 6 [default = 0.5566];
|
||||||
optional float ipd_face_box_height_ratio = 7 [default = 0.3131];
|
optional float ipd_face_box_height_ratio = 7 [default = 0.3131];
|
||||||
|
|
||||||
|
// The max look up angle before considering the eye distance unstable.
|
||||||
|
optional float max_head_tilt_angle_deg = 8 [default = 12.0];
|
||||||
|
// The max amount of time to use an old eye distance when the face look angle
|
||||||
|
// is unstable.
|
||||||
|
optional int32 max_facesize_history_us = 9 [default = 8000000];
|
||||||
}
|
}
|
||||||
|
|
|
@ -345,8 +345,7 @@ TEST(SceneCroppingCalculatorTest, ChecksPriorFrameBufferSize) {
|
||||||
TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) {
|
TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) {
|
||||||
const CalculatorGraphConfig::Node config =
|
const CalculatorGraphConfig::Node config =
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(absl::Substitute(
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(absl::Substitute(
|
||||||
kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight,
|
kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight));
|
||||||
kTargetSizeType, 0, kPriorFrameBufferSize));
|
|
||||||
auto runner = absl::make_unique<CalculatorRunner>(config);
|
auto runner = absl::make_unique<CalculatorRunner>(config);
|
||||||
const auto status = runner->Run();
|
const auto status = runner->Run();
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
|
|
|
@ -220,7 +220,7 @@ absl::Status KinematicPathSolver::GetTargetPosition(int* target_position) {
|
||||||
|
|
||||||
absl::Status KinematicPathSolver::UpdatePixelsPerDegree(
|
absl::Status KinematicPathSolver::UpdatePixelsPerDegree(
|
||||||
const float pixels_per_degree) {
|
const float pixels_per_degree) {
|
||||||
RET_CHECK_GT(pixels_per_degree_, 0)
|
RET_CHECK_GT(pixels_per_degree, 0)
|
||||||
<< "pixels_per_degree must be larger than 0.";
|
<< "pixels_per_degree must be larger than 0.";
|
||||||
pixels_per_degree_ = pixels_per_degree;
|
pixels_per_degree_ = pixels_per_degree;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -38,7 +38,7 @@ node {
|
||||||
output_stream: "TENSORS:detection_tensors"
|
output_stream: "TENSORS:detection_tensors"
|
||||||
options: {
|
options: {
|
||||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
model_path: "mediapipe/models/face_detection_back.tflite"
|
model_path: "mediapipe/modules/face_detection/face_detection_back.tflite"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -111,26 +111,13 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Maps detection label IDs to the corresponding label text ("Face"). The label
|
|
||||||
# map is provided in the label_map_path option.
|
|
||||||
node {
|
|
||||||
calculator: "DetectionLabelIdToTextCalculator"
|
|
||||||
input_stream: "filtered_detections"
|
|
||||||
output_stream: "labeled_detections"
|
|
||||||
options: {
|
|
||||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
|
||||||
label_map_path: "mediapipe/models/face_detection_back_labelmap.txt"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
||||||
# letterboxed image (after image transformation with the FIT scale mode) to the
|
# letterboxed image (after image transformation with the FIT scale mode) to the
|
||||||
# corresponding locations on the same image with the letterbox removed (the
|
# corresponding locations on the same image with the letterbox removed (the
|
||||||
# input image to the graph before image transformation).
|
# input image to the graph before image transformation).
|
||||||
node {
|
node {
|
||||||
calculator: "DetectionLetterboxRemovalCalculator"
|
calculator: "DetectionLetterboxRemovalCalculator"
|
||||||
input_stream: "DETECTIONS:labeled_detections"
|
input_stream: "DETECTIONS:filtered_detections"
|
||||||
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||||
output_stream: "DETECTIONS:output_detections"
|
output_stream: "DETECTIONS:output_detections"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
# Copyright 2020 The MediaPipe Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "upper_body_pose_tracking_cpu",
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/examples/desktop:demo_run_graph_main",
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_cpu_deps",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Linux only
|
|
||||||
cc_binary(
|
|
||||||
name = "upper_body_pose_tracking_gpu",
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
|
||||||
],
|
|
||||||
)
|
|
|
@ -61,8 +61,7 @@ objc_library(
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||||
|
|
|
@ -63,7 +63,7 @@ objc_library(
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
# Copyright 2020 The MediaPipe Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
load(
|
|
||||||
"@build_bazel_rules_apple//apple:ios.bzl",
|
|
||||||
"ios_application",
|
|
||||||
)
|
|
||||||
load(
|
|
||||||
"//mediapipe/examples/ios:bundle_id.bzl",
|
|
||||||
"BUNDLE_ID_PREFIX",
|
|
||||||
"example_provisioning",
|
|
||||||
)
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
MIN_IOS_VERSION = "10.0"
|
|
||||||
|
|
||||||
alias(
|
|
||||||
name = "upperbodyposetrackinggpu",
|
|
||||||
actual = "UpperBodyPoseTrackingGpuApp",
|
|
||||||
)
|
|
||||||
|
|
||||||
ios_application(
|
|
||||||
name = "UpperBodyPoseTrackingGpuApp",
|
|
||||||
app_icons = ["//mediapipe/examples/ios/common:AppIcon"],
|
|
||||||
bundle_id = BUNDLE_ID_PREFIX + ".UpperBodyPoseTrackingGpu",
|
|
||||||
families = [
|
|
||||||
"iphone",
|
|
||||||
"ipad",
|
|
||||||
],
|
|
||||||
infoplists = [
|
|
||||||
"//mediapipe/examples/ios/common:Info.plist",
|
|
||||||
"Info.plist",
|
|
||||||
],
|
|
||||||
minimum_os_version = MIN_IOS_VERSION,
|
|
||||||
provisioning_profile = example_provisioning(),
|
|
||||||
deps = [
|
|
||||||
":UpperBodyPoseTrackingGpuAppLibrary",
|
|
||||||
"@ios_opencv//:OpencvFramework",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
objc_library(
|
|
||||||
name = "UpperBodyPoseTrackingGpuAppLibrary",
|
|
||||||
srcs = [
|
|
||||||
"UpperBodyPoseTrackingViewController.mm",
|
|
||||||
],
|
|
||||||
hdrs = [
|
|
||||||
"UpperBodyPoseTrackingViewController.h",
|
|
||||||
],
|
|
||||||
copts = ["-std=c++17"],
|
|
||||||
data = [
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
|
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
|
||||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
|
||||||
] + select({
|
|
||||||
"//mediapipe:ios_i386": [],
|
|
||||||
"//mediapipe:ios_x86_64": [],
|
|
||||||
"//conditions:default": [
|
|
||||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
)
|
|
|
@ -1,16 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
|
||||||
<plist version="1.0">
|
|
||||||
<dict>
|
|
||||||
<key>CameraPosition</key>
|
|
||||||
<string>back</string>
|
|
||||||
<key>MainViewController</key>
|
|
||||||
<string>UpperBodyPoseTrackingViewController</string>
|
|
||||||
<key>GraphOutputStream</key>
|
|
||||||
<string>output_video</string>
|
|
||||||
<key>GraphInputStream</key>
|
|
||||||
<string>input_video</string>
|
|
||||||
<key>GraphName</key>
|
|
||||||
<string>upper_body_pose_tracking_gpu</string>
|
|
||||||
</dict>
|
|
||||||
</plist>
|
|
|
@ -1,53 +0,0 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#import "UpperBodyPoseTrackingViewController.h"
|
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
|
||||||
|
|
||||||
static const char* kLandmarksOutputStream = "pose_landmarks";
|
|
||||||
|
|
||||||
@implementation UpperBodyPoseTrackingViewController
|
|
||||||
|
|
||||||
#pragma mark - UIViewController methods
|
|
||||||
|
|
||||||
- (void)viewDidLoad {
|
|
||||||
[super viewDidLoad];
|
|
||||||
|
|
||||||
[self.mediapipeGraph addFrameOutputStream:kLandmarksOutputStream
|
|
||||||
outputPacketType:MPPPacketTypeRaw];
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma mark - MPPGraphDelegate methods
|
|
||||||
|
|
||||||
// Receives a raw packet from the MediaPipe graph. Invoked on a MediaPipe worker thread.
|
|
||||||
- (void)mediapipeGraph:(MPPGraph*)graph
|
|
||||||
didOutputPacket:(const ::mediapipe::Packet&)packet
|
|
||||||
fromStream:(const std::string&)streamName {
|
|
||||||
if (streamName == kLandmarksOutputStream) {
|
|
||||||
if (packet.IsEmpty()) {
|
|
||||||
NSLog(@"[TS:%lld] No pose landmarks", packet.Timestamp().Value());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>();
|
|
||||||
NSLog(@"[TS:%lld] Number of pose landmarks: %d", packet.Timestamp().Value(),
|
|
||||||
landmarks.landmark_size());
|
|
||||||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
|
||||||
NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(),
|
|
||||||
landmarks.landmark(i).y(), landmarks.landmark(i).z());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@end
|
|
|
@ -15,6 +15,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -27,10 +28,28 @@ package_group(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files([
|
bzl_library(
|
||||||
"transitive_protos.bzl",
|
name = "transitive_protos_bzl",
|
||||||
"encode_binary_proto.bzl",
|
srcs = [
|
||||||
])
|
"transitive_protos.bzl",
|
||||||
|
],
|
||||||
|
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
bzl_library(
|
||||||
|
name = "encode_binary_proto_bzl",
|
||||||
|
srcs = [
|
||||||
|
"encode_binary_proto.bzl",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "encode_binary_proto",
|
||||||
|
actual = ":encode_binary_proto_bzl",
|
||||||
|
deprecation = "Use encode_binary_proto_bzl",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "calculator_proto",
|
name = "calculator_proto",
|
||||||
|
|
|
@ -1,15 +1,8 @@
|
||||||
package(
|
package(
|
||||||
default_visibility = [":preview_users"],
|
default_visibility = ["//visibility:public"],
|
||||||
features = ["-use_header_modules"],
|
features = ["-use_header_modules"],
|
||||||
)
|
)
|
||||||
|
|
||||||
package_group(
|
|
||||||
name = "preview_users",
|
|
||||||
packages = [
|
|
||||||
"//mediapipe/...",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -422,6 +422,9 @@ message CalculatorGraphConfig {
|
||||||
// the graph config.
|
// the graph config.
|
||||||
string type = 20;
|
string type = 20;
|
||||||
|
|
||||||
// Can be used for annotating a graph.
|
// The types and default values for graph options, in proto2 syntax.
|
||||||
MediaPipeOptions options = 1001;
|
MediaPipeOptions options = 1001;
|
||||||
|
|
||||||
|
// The types and default values for graph options, in proto3 syntax.
|
||||||
|
repeated google.protobuf.Any graph_options = 1002;
|
||||||
}
|
}
|
||||||
|
|
|
@ -411,7 +411,8 @@ absl::Status CalculatorGraph::Initialize(
|
||||||
|
|
||||||
absl::Status CalculatorGraph::ObserveOutputStream(
|
absl::Status CalculatorGraph::ObserveOutputStream(
|
||||||
const std::string& stream_name,
|
const std::string& stream_name,
|
||||||
std::function<absl::Status(const Packet&)> packet_callback) {
|
std::function<absl::Status(const Packet&)> packet_callback,
|
||||||
|
bool observe_timestamp_bounds) {
|
||||||
RET_CHECK(initialized_).SetNoLogging()
|
RET_CHECK(initialized_).SetNoLogging()
|
||||||
<< "CalculatorGraph is not initialized.";
|
<< "CalculatorGraph is not initialized.";
|
||||||
// TODO Allow output observers to be attached by graph level
|
// TODO Allow output observers to be attached by graph level
|
||||||
|
@ -425,7 +426,7 @@ absl::Status CalculatorGraph::ObserveOutputStream(
|
||||||
auto observer = absl::make_unique<internal::OutputStreamObserver>();
|
auto observer = absl::make_unique<internal::OutputStreamObserver>();
|
||||||
MP_RETURN_IF_ERROR(observer->Initialize(
|
MP_RETURN_IF_ERROR(observer->Initialize(
|
||||||
stream_name, &any_packet_type_, std::move(packet_callback),
|
stream_name, &any_packet_type_, std::move(packet_callback),
|
||||||
&output_stream_managers_[output_stream_index]));
|
&output_stream_managers_[output_stream_index], observe_timestamp_bounds));
|
||||||
graph_output_streams_.push_back(std::move(observer));
|
graph_output_streams_.push_back(std::move(observer));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,8 @@ class CalculatorGraph {
|
||||||
// TODO: Rename to AddOutputStreamCallback.
|
// TODO: Rename to AddOutputStreamCallback.
|
||||||
absl::Status ObserveOutputStream(
|
absl::Status ObserveOutputStream(
|
||||||
const std::string& stream_name,
|
const std::string& stream_name,
|
||||||
std::function<absl::Status(const Packet&)> packet_callback);
|
std::function<absl::Status(const Packet&)> packet_callback,
|
||||||
|
bool observe_timestamp_bounds = false);
|
||||||
|
|
||||||
// Adds an OutputStreamPoller for a stream. This provides a synchronous,
|
// Adds an OutputStreamPoller for a stream. This provides a synchronous,
|
||||||
// polling API for accessing a stream's output. Should only be called before
|
// polling API for accessing a stream's output. Should only be called before
|
||||||
|
|
|
@ -1518,5 +1518,72 @@ TEST(CalculatorGraphBoundsTest, OffsetAndBound) {
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A Calculator that sends empty output stream packets.
|
||||||
|
class EmptyPacketCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
|
cc->Inputs().Index(0).Set<int>();
|
||||||
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
if (cc->InputTimestamp().Value() % 2 == 0) {
|
||||||
|
cc->Outputs().Index(0).AddPacket(Packet().At(cc->InputTimestamp()));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
REGISTER_CALCULATOR(EmptyPacketCalculator);
|
||||||
|
|
||||||
|
// This test shows that an output timestamp bound can be specified by outputing
|
||||||
|
// an empty packet with a settled timestamp.
|
||||||
|
TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
|
||||||
|
// OffsetAndBoundCalculator runs on parallel threads and sends ts
|
||||||
|
// occasionally.
|
||||||
|
std::string config_str = R"(
|
||||||
|
input_stream: "input_0"
|
||||||
|
node {
|
||||||
|
calculator: "EmptyPacketCalculator"
|
||||||
|
input_stream: "input_0"
|
||||||
|
output_stream: "empty_0"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "ProcessBoundToPacketCalculator"
|
||||||
|
input_stream: "empty_0"
|
||||||
|
output_stream: "output_0"
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||||
|
CalculatorGraph graph;
|
||||||
|
std::vector<Packet> output_0_packets;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
|
||||||
|
output_0_packets.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
// Send in packets.
|
||||||
|
for (int i = 0; i < 9; ++i) {
|
||||||
|
const int ts = 10 + i * 10;
|
||||||
|
Packet p = MakePacket<int>(i).At(Timestamp(ts));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9 empty packets are converted to bounds and then to packets.
|
||||||
|
EXPECT_EQ(output_0_packets.size(), 9);
|
||||||
|
for (int i = 0; i < 9; ++i) {
|
||||||
|
EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the graph.
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -16,11 +16,20 @@
|
||||||
# The dependencies of mediapipe.
|
# The dependencies of mediapipe.
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||||
|
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:private"])
|
package(default_visibility = ["//visibility:private"])
|
||||||
|
|
||||||
|
bzl_library(
|
||||||
|
name = "expand_template_bzl",
|
||||||
|
srcs = [
|
||||||
|
"expand_template.bzl",
|
||||||
|
],
|
||||||
|
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
proto_library(
|
proto_library(
|
||||||
name = "proto_descriptor_proto",
|
name = "proto_descriptor_proto",
|
||||||
srcs = ["proto_descriptor.proto"],
|
srcs = ["proto_descriptor.proto"],
|
||||||
|
|
|
@ -295,6 +295,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:image_format_cc_proto",
|
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"//mediapipe/framework:port",
|
"//mediapipe/framework:port",
|
||||||
|
"//mediapipe/framework:type_map",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/type_map.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
// TODO Refactor common code from GpuBufferToImageFrameCalculator
|
// TODO Refactor common code from GpuBufferToImageFrameCalculator
|
||||||
|
@ -67,8 +69,7 @@ bool Image::ConvertToGpu() const {
|
||||||
#else
|
#else
|
||||||
if (use_gpu_) return true; // Already on GPU.
|
if (use_gpu_) return true; // Already on GPU.
|
||||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
auto packet = MakePacket<ImageFrame>(std::move(*image_frame_));
|
auto packet = PointToForeign<ImageFrame>(image_frame_.get());
|
||||||
image_frame_ = nullptr;
|
|
||||||
CFHolder<CVPixelBufferRef> buffer;
|
CFHolder<CVPixelBufferRef> buffer;
|
||||||
auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer);
|
auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer);
|
||||||
CHECK_OK(status);
|
CHECK_OK(status);
|
||||||
|
@ -94,4 +95,7 @@ bool Image::ConvertToGpu() const {
|
||||||
#endif // MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_TYPE(mediapipe::Image, "::mediapipe::Image", nullptr,
|
||||||
|
nullptr);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -72,8 +72,8 @@ class Image {
|
||||||
|
|
||||||
// Creates an Image representing the same image content as the ImageFrame
|
// Creates an Image representing the same image content as the ImageFrame
|
||||||
// the input shared pointer points to, and retaining shared ownership.
|
// the input shared pointer points to, and retaining shared ownership.
|
||||||
explicit Image(ImageFrameSharedPtr frame_buffer)
|
explicit Image(ImageFrameSharedPtr image_frame)
|
||||||
: image_frame_(std::move(frame_buffer)) {
|
: image_frame_(std::move(image_frame)) {
|
||||||
use_gpu_ = false;
|
use_gpu_ = false;
|
||||||
pixel_mutex_ = std::make_shared<absl::Mutex>();
|
pixel_mutex_ = std::make_shared<absl::Mutex>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
// Zero and negative values are not checked here.
|
||||||
|
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
|
||||||
|
|
||||||
int BhwcBatchFromShape(const Tensor::Shape& shape) {
|
int BhwcBatchFromShape(const Tensor::Shape& shape) {
|
||||||
LOG_IF(FATAL, shape.dims.empty())
|
LOG_IF(FATAL, shape.dims.empty())
|
||||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
||||||
|
@ -237,6 +240,12 @@ void Tensor::AllocateOpenGlTexture2d() const {
|
||||||
glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_,
|
glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_,
|
||||||
texture_height_);
|
texture_height_);
|
||||||
} else {
|
} else {
|
||||||
|
// GLES2.0 supports only clamp addressing mode for NPOT textures.
|
||||||
|
// If any of dimensions is NPOT then both addressing modes are clamp.
|
||||||
|
if (!IsPowerOfTwo(texture_width_) || !IsPowerOfTwo(texture_height_)) {
|
||||||
|
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
|
||||||
|
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
|
||||||
|
}
|
||||||
// We assume all contexts will have the same extensions, so we only check
|
// We assume all contexts will have the same extensions, so we only check
|
||||||
// once for OES_texture_float extension, to save time.
|
// once for OES_texture_float extension, to save time.
|
||||||
static bool has_oes_extension =
|
static bool has_oes_extension =
|
||||||
|
|
|
@ -14,13 +14,16 @@
|
||||||
|
|
||||||
#include "mediapipe/framework/graph_output_stream.h"
|
#include "mediapipe/framework/graph_output_stream.h"
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
absl::Status GraphOutputStream::Initialize(
|
absl::Status GraphOutputStream::Initialize(
|
||||||
const std::string& stream_name, const PacketType* packet_type,
|
const std::string& stream_name, const PacketType* packet_type,
|
||||||
OutputStreamManager* output_stream_manager) {
|
OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
|
||||||
RET_CHECK(output_stream_manager);
|
RET_CHECK(output_stream_manager);
|
||||||
|
|
||||||
// Initializes input_stream_handler_ with one input stream as the observer.
|
// Initializes input_stream_handler_ with one input stream as the observer.
|
||||||
|
@ -31,6 +34,7 @@ absl::Status GraphOutputStream::Initialize(
|
||||||
input_stream_handler_ = absl::make_unique<GraphOutputStreamHandler>(
|
input_stream_handler_ = absl::make_unique<GraphOutputStreamHandler>(
|
||||||
tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(),
|
tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(),
|
||||||
/*calculator_run_in_parallel=*/false);
|
/*calculator_run_in_parallel=*/false);
|
||||||
|
input_stream_handler_->SetProcessTimestampBounds(observe_timestamp_bounds);
|
||||||
const CollectionItemId& id = tag_map->BeginId();
|
const CollectionItemId& id = tag_map->BeginId();
|
||||||
input_stream_ = absl::make_unique<InputStreamManager>();
|
input_stream_ = absl::make_unique<InputStreamManager>();
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
|
@ -52,20 +56,58 @@ void GraphOutputStream::PrepareForRun(
|
||||||
absl::Status OutputStreamObserver::Initialize(
|
absl::Status OutputStreamObserver::Initialize(
|
||||||
const std::string& stream_name, const PacketType* packet_type,
|
const std::string& stream_name, const PacketType* packet_type,
|
||||||
std::function<absl::Status(const Packet&)> packet_callback,
|
std::function<absl::Status(const Packet&)> packet_callback,
|
||||||
OutputStreamManager* output_stream_manager) {
|
OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
|
||||||
RET_CHECK(output_stream_manager);
|
RET_CHECK(output_stream_manager);
|
||||||
|
|
||||||
packet_callback_ = std::move(packet_callback);
|
packet_callback_ = std::move(packet_callback);
|
||||||
|
observe_timestamp_bounds_ = observe_timestamp_bounds;
|
||||||
return GraphOutputStream::Initialize(stream_name, packet_type,
|
return GraphOutputStream::Initialize(stream_name, packet_type,
|
||||||
output_stream_manager);
|
output_stream_manager,
|
||||||
|
observe_timestamp_bounds);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status OutputStreamObserver::Notify() {
|
absl::Status OutputStreamObserver::Notify() {
|
||||||
|
// Lets one thread perform packets notification as much as possible.
|
||||||
|
// Other threads should quit if a thread is already performing notification.
|
||||||
|
{
|
||||||
|
absl::MutexLock l(&mutex_);
|
||||||
|
|
||||||
|
if (notifying_ == false) {
|
||||||
|
notifying_ = true;
|
||||||
|
} else {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
}
|
||||||
while (true) {
|
while (true) {
|
||||||
bool empty;
|
bool empty;
|
||||||
Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty);
|
Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty);
|
||||||
if (empty) {
|
if (empty) {
|
||||||
break;
|
// Emits an empty packet at timestamp_bound.PreviousAllowedInStream().
|
||||||
|
if (observe_timestamp_bounds_ && min_timestamp < Timestamp::Done()) {
|
||||||
|
Timestamp settled = (min_timestamp == Timestamp::PostStream()
|
||||||
|
? Timestamp::PostStream()
|
||||||
|
: min_timestamp.PreviousAllowedInStream());
|
||||||
|
if (last_processed_ts_ < settled) {
|
||||||
|
MP_RETURN_IF_ERROR(packet_callback_(Packet().At(settled)));
|
||||||
|
last_processed_ts_ = settled;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Last check to make sure that the min timestamp or bound doesn't change.
|
||||||
|
// If so, flips notifying_ to false to allow any other threads to perform
|
||||||
|
// notification when new packets/timestamp bounds arrive. Otherwise, in
|
||||||
|
// case of the min timestamp or bound getting updated, jumps to the
|
||||||
|
// beginning of the notification loop for a new iteration.
|
||||||
|
{
|
||||||
|
absl::MutexLock l(&mutex_);
|
||||||
|
Timestamp new_min_timestamp =
|
||||||
|
input_stream_->MinTimestampOrBound(&empty);
|
||||||
|
if (new_min_timestamp == min_timestamp) {
|
||||||
|
notifying_ = false;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
int num_packets_dropped = 0;
|
int num_packets_dropped = 0;
|
||||||
bool stream_is_done = false;
|
bool stream_is_done = false;
|
||||||
|
@ -75,6 +117,7 @@ absl::Status OutputStreamObserver::Notify() {
|
||||||
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
|
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
|
||||||
num_packets_dropped, input_stream_->Name());
|
num_packets_dropped, input_stream_->Name());
|
||||||
MP_RETURN_IF_ERROR(packet_callback_(packet));
|
MP_RETURN_IF_ERROR(packet_callback_(packet));
|
||||||
|
last_processed_ts_ = min_timestamp;
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,8 @@ class GraphOutputStream {
|
||||||
// is not transferred to the graph output stream object.
|
// is not transferred to the graph output stream object.
|
||||||
absl::Status Initialize(const std::string& stream_name,
|
absl::Status Initialize(const std::string& stream_name,
|
||||||
const PacketType* packet_type,
|
const PacketType* packet_type,
|
||||||
OutputStreamManager* output_stream_manager);
|
OutputStreamManager* output_stream_manager,
|
||||||
|
bool observe_timestamp_bounds = false);
|
||||||
|
|
||||||
// Installs callbacks into its GraphOutputStreamHandler.
|
// Installs callbacks into its GraphOutputStreamHandler.
|
||||||
virtual void PrepareForRun(std::function<void()> notification_callback,
|
virtual void PrepareForRun(std::function<void()> notification_callback,
|
||||||
|
@ -99,6 +100,10 @@ class GraphOutputStream {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool observe_timestamp_bounds_;
|
||||||
|
absl::Mutex mutex_;
|
||||||
|
bool notifying_ ABSL_GUARDED_BY(mutex_) = false;
|
||||||
|
Timestamp last_processed_ts_ = Timestamp::Unstarted();
|
||||||
std::unique_ptr<InputStreamHandler> input_stream_handler_;
|
std::unique_ptr<InputStreamHandler> input_stream_handler_;
|
||||||
std::unique_ptr<InputStreamManager> input_stream_;
|
std::unique_ptr<InputStreamManager> input_stream_;
|
||||||
};
|
};
|
||||||
|
@ -112,7 +117,8 @@ class OutputStreamObserver : public GraphOutputStream {
|
||||||
absl::Status Initialize(
|
absl::Status Initialize(
|
||||||
const std::string& stream_name, const PacketType* packet_type,
|
const std::string& stream_name, const PacketType* packet_type,
|
||||||
std::function<absl::Status(const Packet&)> packet_callback,
|
std::function<absl::Status(const Packet&)> packet_callback,
|
||||||
OutputStreamManager* output_stream_manager);
|
OutputStreamManager* output_stream_manager,
|
||||||
|
bool observe_timestamp_bounds = false);
|
||||||
|
|
||||||
// Notifies the observer of new packets emitted by the observed
|
// Notifies the observer of new packets emitted by the observed
|
||||||
// output stream.
|
// output stream.
|
||||||
|
@ -128,6 +134,7 @@ class OutputStreamObserver : public GraphOutputStream {
|
||||||
|
|
||||||
// OutputStreamPollerImpl that returns packets to the caller via
|
// OutputStreamPollerImpl that returns packets to the caller via
|
||||||
// Next()/NextBatch().
|
// Next()/NextBatch().
|
||||||
|
// TODO: Support observe_timestamp_bounds.
|
||||||
class OutputStreamPollerImpl : public GraphOutputStream {
|
class OutputStreamPollerImpl : public GraphOutputStream {
|
||||||
public:
|
public:
|
||||||
virtual ~OutputStreamPollerImpl() {}
|
virtual ~OutputStreamPollerImpl() {}
|
||||||
|
|
|
@ -20,6 +20,9 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe;
|
package mediapipe;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.proto";
|
||||||
|
option java_outer_classname = "MediaPipeOptionsProto";
|
||||||
|
|
||||||
// Options used by a MediaPipe object.
|
// Options used by a MediaPipe object.
|
||||||
message MediaPipeOptions {
|
message MediaPipeOptions {
|
||||||
extensions 20000 to max;
|
extensions 20000 to max;
|
||||||
|
|
|
@ -101,8 +101,8 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (packet.IsEmpty()) {
|
if (packet.IsEmpty()) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
SetNextTimestampBound(packet.Timestamp().NextAllowedInStream());
|
||||||
<< "Empty packet sent to stream \"" << Name() << "\".";
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
const Timestamp timestamp = packet.Timestamp();
|
const Timestamp timestamp = packet.Timestamp();
|
||||||
|
|
|
@ -20,6 +20,7 @@ load(
|
||||||
"mediapipe_binary_graph",
|
"mediapipe_binary_graph",
|
||||||
)
|
)
|
||||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||||
|
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -29,6 +30,30 @@ exports_files([
|
||||||
"simple_subgraph_template.cc",
|
"simple_subgraph_template.cc",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
bzl_library(
|
||||||
|
name = "mediapipe_graph_bzl",
|
||||||
|
srcs = [
|
||||||
|
"mediapipe_graph.bzl",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":build_defs_bzl",
|
||||||
|
"//mediapipe/framework:encode_binary_proto",
|
||||||
|
"//mediapipe/framework:transitive_protos_bzl",
|
||||||
|
"//mediapipe/framework/deps:expand_template_bzl",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
bzl_library(
|
||||||
|
name = "build_defs_bzl",
|
||||||
|
srcs = [
|
||||||
|
"build_defs.bzl",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/framework:__subpackages__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "text_to_binary_graph",
|
name = "text_to_binary_graph",
|
||||||
srcs = ["text_to_binary_graph.cc"],
|
srcs = ["text_to_binary_graph.cc"],
|
||||||
|
@ -744,5 +769,7 @@ cc_test(
|
||||||
|
|
||||||
exports_files(
|
exports_files(
|
||||||
["build_defs.bzl"],
|
["build_defs.bzl"],
|
||||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
visibility = [
|
||||||
|
"//mediapipe/framework:__subpackages__",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mediapipe/framework/tool/sink.h"
|
#include "mediapipe/framework/tool/sink.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
@ -168,8 +169,19 @@ void AddMultiStreamCallback(
|
||||||
std::function<void(const std::vector<Packet>&)> callback,
|
std::function<void(const std::vector<Packet>&)> callback,
|
||||||
CalculatorGraphConfig* config,
|
CalculatorGraphConfig* config,
|
||||||
std::pair<std::string, Packet>* side_packet) {
|
std::pair<std::string, Packet>* side_packet) {
|
||||||
|
std::map<std::string, Packet> side_packets;
|
||||||
|
AddMultiStreamCallback(streams, callback, config, &side_packets,
|
||||||
|
/*observe_timestamp_bounds=*/false);
|
||||||
|
*side_packet = *side_packets.begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddMultiStreamCallback(
|
||||||
|
const std::vector<std::string>& streams,
|
||||||
|
std::function<void(const std::vector<Packet>&)> callback,
|
||||||
|
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
|
||||||
|
bool observe_timestamp_bounds) {
|
||||||
CHECK(config);
|
CHECK(config);
|
||||||
CHECK(side_packet);
|
CHECK(side_packets);
|
||||||
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
||||||
const std::string name = GetUnusedNodeName(
|
const std::string name = GetUnusedNodeName(
|
||||||
*config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_")));
|
*config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_")));
|
||||||
|
@ -179,15 +191,23 @@ void AddMultiStreamCallback(
|
||||||
sink_node->add_input_stream(stream_name);
|
sink_node->add_input_stream(stream_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (observe_timestamp_bounds) {
|
||||||
|
const std::string observe_ts_bounds_packet_name = GetUnusedSidePacketName(
|
||||||
|
*config, absl::StrCat(name, "_observe_ts_bounds"));
|
||||||
|
sink_node->add_input_side_packet(absl::StrCat(
|
||||||
|
"OBSERVE_TIMESTAMP_BOUNDS:", observe_ts_bounds_packet_name));
|
||||||
|
InsertIfNotPresent(side_packets, observe_ts_bounds_packet_name,
|
||||||
|
MakePacket<bool>(true));
|
||||||
|
}
|
||||||
const std::string input_side_packet_name =
|
const std::string input_side_packet_name =
|
||||||
GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback"));
|
GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback"));
|
||||||
side_packet->first = input_side_packet_name;
|
|
||||||
sink_node->add_input_side_packet(
|
sink_node->add_input_side_packet(
|
||||||
absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name));
|
absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name));
|
||||||
|
|
||||||
side_packet->second =
|
InsertIfNotPresent(
|
||||||
|
side_packets, input_side_packet_name,
|
||||||
MakePacket<std::function<void(const std::vector<Packet>&)>>(
|
MakePacket<std::function<void(const std::vector<Packet>&)>>(
|
||||||
std::move(callback));
|
std::move(callback)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddCallbackWithHeaderCalculator(const std::string& stream_name,
|
void AddCallbackWithHeaderCalculator(const std::string& stream_name,
|
||||||
|
@ -240,6 +260,10 @@ absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "InputSidePackets must use tags.";
|
<< "InputSidePackets must use tags.";
|
||||||
}
|
}
|
||||||
|
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS")) {
|
||||||
|
cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Set<bool>();
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
}
|
||||||
|
|
||||||
int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1;
|
int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1;
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
|
@ -266,6 +290,12 @@ absl::Status CallbackCalculator::Open(CalculatorContext* cc) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "missing callback.";
|
<< "missing callback.";
|
||||||
}
|
}
|
||||||
|
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS") &&
|
||||||
|
!cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Get<bool>()) {
|
||||||
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
|
<< "The value of the OBSERVE_TIMESTAMP_BOUNDS input side packet "
|
||||||
|
"must be set to true";
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -115,6 +115,12 @@ void AddMultiStreamCallback(
|
||||||
std::function<void(const std::vector<Packet>&)> callback,
|
std::function<void(const std::vector<Packet>&)> callback,
|
||||||
CalculatorGraphConfig* config, std::pair<std::string, Packet>* side_packet);
|
CalculatorGraphConfig* config, std::pair<std::string, Packet>* side_packet);
|
||||||
|
|
||||||
|
void AddMultiStreamCallback(
|
||||||
|
const std::vector<std::string>& streams,
|
||||||
|
std::function<void(const std::vector<Packet>&)> callback,
|
||||||
|
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
|
||||||
|
bool observe_timestamp_bounds = false);
|
||||||
|
|
||||||
// Add a CallbackWithHeaderCalculator to intercept packets sent on
|
// Add a CallbackWithHeaderCalculator to intercept packets sent on
|
||||||
// stream stream_name, and the header packet on stream stream_header.
|
// stream stream_name, and the header packet on stream stream_header.
|
||||||
// The input side packet with the produced name callback_side_packet_name
|
// The input side packet with the produced name callback_side_packet_name
|
||||||
|
|
|
@ -146,5 +146,63 @@ TEST(CallbackTest, TestAddMultiStreamCallback) {
|
||||||
EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9));
|
EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TimestampBoundTestCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
|
cc->Outputs().Index(0).Set<int>();
|
||||||
|
cc->Outputs().Index(1).Set<int>();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
if (count_ % 5 == 0) {
|
||||||
|
cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_ + 1));
|
||||||
|
cc->Outputs().Index(1).SetNextTimestampBound(Timestamp(count_ + 1));
|
||||||
|
}
|
||||||
|
++count_;
|
||||||
|
if (count_ == 13) {
|
||||||
|
return tool::StatusStop();
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int count_ = 1;
|
||||||
|
};
|
||||||
|
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
|
||||||
|
|
||||||
|
TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) {
|
||||||
|
std::string config_str = R"(
|
||||||
|
node {
|
||||||
|
calculator: "TimestampBoundTestCalculator"
|
||||||
|
output_stream: "foo"
|
||||||
|
output_stream: "bar"
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||||
|
|
||||||
|
std::vector<int> sums;
|
||||||
|
|
||||||
|
std::map<std::string, Packet> side_packets;
|
||||||
|
tool::AddMultiStreamCallback(
|
||||||
|
{"foo", "bar"},
|
||||||
|
[&sums](const std::vector<Packet>& packets) {
|
||||||
|
Packet foo_p = packets[0];
|
||||||
|
Packet bar_p = packets[1];
|
||||||
|
ASSERT_TRUE(foo_p.IsEmpty() && bar_p.IsEmpty());
|
||||||
|
int foo = foo_p.Timestamp().Value();
|
||||||
|
int bar = bar_p.Timestamp().Value();
|
||||||
|
sums.push_back(foo + bar);
|
||||||
|
},
|
||||||
|
&graph_config, &side_packets, true);
|
||||||
|
|
||||||
|
CalculatorGraph graph(graph_config);
|
||||||
|
MP_ASSERT_OK(graph.StartRun(side_packets));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
|
||||||
|
EXPECT_THAT(sums, testing::ElementsAre(10, 20));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
load("//mediapipe/gpu:metal.bzl", "metal_library")
|
load("//mediapipe/gpu:metal.bzl", "metal_library")
|
||||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
|
||||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
@ -240,6 +240,12 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "gpu_origin_proto",
|
||||||
|
srcs = ["gpu_origin.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "pixel_buffer_pool_util",
|
name = "pixel_buffer_pool_util",
|
||||||
srcs = ["pixel_buffer_pool_util.mm"],
|
srcs = ["pixel_buffer_pool_util.mm"],
|
||||||
|
@ -460,6 +466,8 @@ cc_library(
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework:calculator_node",
|
"//mediapipe/framework:calculator_node",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
|
"//mediapipe/util:resource_cache",
|
||||||
|
"@com_google_absl//absl/hash",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
] + select({
|
] + select({
|
||||||
|
@ -760,8 +768,10 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":gl_calculator_helper",
|
":gl_calculator_helper",
|
||||||
":gl_quad_renderer",
|
":gl_quad_renderer",
|
||||||
|
":gpu_buffer",
|
||||||
":shader_util",
|
":shader_util",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
|
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
|
||||||
|
|
|
@ -563,8 +563,14 @@ class GlFenceSyncPoint : public GlSyncPoint {
|
||||||
|
|
||||||
void WaitOnGpu() override {
|
void WaitOnGpu() override {
|
||||||
if (!sync_) return;
|
if (!sync_) return;
|
||||||
// TODO: do not wait if we are already on the same context?
|
// TODO: do not wait if we are already on the same context?
|
||||||
|
// WebGL2 specifies a waitSync call, but since cross-context
|
||||||
|
// synchronization is not supported, it's actually a no-op. Firefox prints
|
||||||
|
// a warning when it's called, so let's just skip the call. See
|
||||||
|
// b/184637485 for details.
|
||||||
|
#ifndef __EMSCRIPTEN__
|
||||||
glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED);
|
glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsReady() override {
|
bool IsReady() override {
|
||||||
|
|