Project import generated by Copybara.

GitOrigin-RevId: ff83882955f1a1e2a043ff4e71278be9d7217bbe
This commit is contained in:
MediaPipe Team 2021-05-04 18:30:15 -07:00 committed by chuoling
parent ecb5b5f44a
commit a9b643e0f5
210 changed files with 5312 additions and 3838 deletions

View File

@ -23,6 +23,7 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \ build-essential \
gcc-8 g++-8 \
ca-certificates \ ca-certificates \
curl \ curl \
ffmpeg \ ffmpeg \
@ -44,6 +45,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /usr/bin/g++ g++ /usr/bin/g++-8
RUN pip3 install --upgrade setuptools RUN pip3 install --upgrade setuptools
RUN pip3 install wheel RUN pip3 install wheel
RUN pip3 install future RUN pip3 install future
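The added `gcc-8 g++-8` packages and the `update-alternatives` line pin the image's default compiler to GCC 8. A minimal sanity check after building the image, assuming it is tagged `mediapipe` (the tag is not part of this change):

```bash
# Hypothetical verification; "mediapipe" is an assumed image tag.
docker run --rm mediapipe gcc --version                      # should report 8.x
docker run --rm mediapipe g++ --version                      # should report 8.x
docker run --rm mediapipe update-alternatives --display gcc  # should point at gcc-8
```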

View File

@ -337,6 +337,8 @@ maven_install(
"androidx.test.espresso:espresso-core:3.1.1", "androidx.test.espresso:espresso-core:3.1.1",
"com.github.bumptech.glide:glide:4.11.0", "com.github.bumptech.glide:glide:4.11.0",
"com.google.android.material:material:aar:1.0.0-rc01", "com.google.android.material:material:aar:1.0.0-rc01",
"com.google.auto.value:auto-value:1.6.4",
"com.google.auto.value:auto-value-annotations:1.6.4",
"com.google.code.findbugs:jsr305:3.0.2", "com.google.code.findbugs:jsr305:3.0.2",
"com.google.flogger:flogger-system-backend:0.3.1", "com.google.flogger:flogger-system-backend:0.3.1",
"com.google.flogger:flogger:0.3.1", "com.google.flogger:flogger:0.3.1",
@ -367,9 +369,9 @@ http_archive(
) )
# Tensorflow repo should always go after the other external dependencies. # Tensorflow repo should always go after the other external dependencies.
# 2021-03-25 # 2021-04-30
_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887" _TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d"
_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f" _TENSORFLOW_SHA256= "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0"
http_archive( http_archive(
name = "org_tensorflow", name = "org_tensorflow",
urls = [ urls = [
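For context, this is roughly how the full archive rule reads with the updated commit and checksum. The URL pattern and `strip_prefix` below follow the usual TensorFlow `http_archive` layout and are assumptions, not part of the visible hunk:

```python
# Sketch only; the real WORKSPACE may also apply local patches.
_TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d"
_TENSORFLOW_SHA256 = "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0"

http_archive(
    name = "org_tensorflow",
    urls = [
        "https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
    ],
    strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT,
    sha256 = _TENSORFLOW_SHA256,
)
```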

View File

@ -17,15 +17,15 @@
# Script to build/run all MediaPipe desktop example apps (with webcam input). # Script to build/run all MediaPipe desktop example apps (with webcam input).
# #
# To build and run all apps and store them in out_dir: # To build and run all apps and store them in out_dir:
# $ ./build_ios_examples.sh -d out_dir # $ ./build_desktop_examples.sh -d out_dir
# Omitting -d and the associated directory saves all generated apps in the # Omitting -d and the associated directory saves all generated apps in the
# current directory. # current directory.
# To build all apps and store them in out_dir: # To build all apps and store them in out_dir:
# $ ./build_ios_examples.sh -d out_dir -b # $ ./build_desktop_examples.sh -d out_dir -b
# Omitting -d and the associated directory saves all generated apps in the # Omitting -d and the associated directory saves all generated apps in the
# current directory. # current directory.
# To run all apps already stored in out_dir: # To run all apps already stored in out_dir:
# $ ./build_ios_examples.sh -d out_dir -r # $ ./build_desktop_examples.sh -d out_dir -r
# Omitting -d and the associated directory assumes all apps are in the current # Omitting -d and the associated directory assumes all apps are in the current
# directory. # directory.
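The corrected header documents three modes; restated below with a placeholder output directory:

```bash
# "out_dir" is a placeholder path.
./build_desktop_examples.sh -d out_dir      # build all apps, then run them
./build_desktop_examples.sh -d out_dir -b   # build only
./build_desktop_examples.sh -d out_dir -r   # run apps already stored in out_dir
```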

View File

@ -187,7 +187,7 @@ node {
``` ```
In the calculator implementation, inputs and outputs are also identified by tag In the calculator implementation, inputs and outputs are also identified by tag
name and index number. In the function below input are output are identified: name and index number. In the function below input and output are identified:
* By index number: The combined input stream is identified simply by index * By index number: The combined input stream is identified simply by index
`0`. `0`.
@ -355,7 +355,6 @@ class PacketClonerCalculator : public CalculatorBase {
current_[i].At(cc->InputTimestamp())); current_[i].At(cc->InputTimestamp()));
// Add a packet to output stream of index i a packet from inputstream i // Add a packet to output stream of index i a packet from inputstream i
// with timestamp common to all present inputs // with timestamp common to all present inputs
//
} else { } else {
cc->Outputs().Index(i).SetNextTimestampBound( cc->Outputs().Index(i).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream()); cc->InputTimestamp().NextAllowedInStream());
@ -382,7 +381,7 @@ defined your calculator class, register it with a macro invocation
REGISTER_CALCULATOR(calculator_class_name). REGISTER_CALCULATOR(calculator_class_name).
Below is a trivial MediaPipe graph that has 3 input streams, 1 node Below is a trivial MediaPipe graph that has 3 input streams, 1 node
(PacketClonerCalculator) and 3 output streams. (PacketClonerCalculator) and 2 output streams.
```proto ```proto
input_stream: "room_mic_signal" input_stream: "room_mic_signal"

View File

@ -83,12 +83,12 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
output_stream: "out3" output_stream: "out3"
node { node {
calculator: "PassThroughculator" calculator: "PassThroughCalculator"
input_stream: "out1" input_stream: "out1"
output_stream: "out2" output_stream: "out2"
} }
node { node {
calculator: "PassThroughculator" calculator: "PassThroughCalculator"
input_stream: "out2" input_stream: "out2"
output_stream: "out3" output_stream: "out3"
} }

View File

@ -57,7 +57,7 @@ Please verify all the necessary packages are installed.
* Android SDK Build-Tools 28 or 29 * Android SDK Build-Tools 28 or 29
* Android SDK Platform-Tools 28 or 29 * Android SDK Platform-Tools 28 or 29
* Android SDK Tools 26.1.1 * Android SDK Tools 26.1.1
* Android NDK 17c or above * Android NDK 19c or above
### Option 1: Build with Bazel in Command Line ### Option 1: Build with Bazel in Command Line
@ -111,7 +111,7 @@ app:
* Verify that Android SDK Build-Tools 28 or 29 is installed. * Verify that Android SDK Build-Tools 28 or 29 is installed.
* Verify that Android SDK Platform-Tools 28 or 29 is installed. * Verify that Android SDK Platform-Tools 28 or 29 is installed.
* Verify that Android SDK Tools 26.1.1 is installed. * Verify that Android SDK Tools 26.1.1 is installed.
* Verify that Android NDK 17c or above is installed. * Verify that Android NDK 19c or above is installed.
* Take note of the Android NDK Location, e.g., * Take note of the Android NDK Location, e.g.,
`/usr/local/home/Android/Sdk/ndk-bundle` or `/usr/local/home/Android/Sdk/ndk-bundle` or
`/usr/local/home/Android/Sdk/ndk/20.0.5594570`. `/usr/local/home/Android/Sdk/ndk/20.0.5594570`.

View File

@ -37,7 +37,7 @@ each project.
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar") load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
mediapipe_aar( mediapipe_aar(
name = "mp_face_detection_aar", name = "mediapipe_face_detection",
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"], calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
) )
``` ```
@ -45,26 +45,29 @@ each project.
2. Run the Bazel build command to generate the AAR. 2. Run the Bazel build command to generate the AAR.
```bash ```bash
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ bazel build -c opt --strip=ALWAYS \
--fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
//path/to/the/aar/build/file:aar_name --fat_apk_cpu=arm64-v8a,armeabi-v7a \
//path/to/the/aar/build/file:aar_name.aar
``` ```
For the face detection AAR target we made in the step 1, run: For the face detection AAR target we made in step 1, run:
```bash ```bash
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \ bazel build -c opt --strip=ALWAYS \
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--fat_apk_cpu=arm64-v8a,armeabi-v7a \
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar
# It should print: # It should print:
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date: # Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar up-to-date:
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar # bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
``` ```
3. (Optional) Save the AAR to your preferred location. 3. (Optional) Save the AAR to your preferred location.
```bash ```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
/absolute/path/to/your/preferred/location /absolute/path/to/your/preferred/location
``` ```
@ -75,7 +78,7 @@ each project.
2. Copy the AAR into app/libs. 2. Copy the AAR into app/libs.
```bash ```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
/path/to/your/app/libs/ /path/to/your/app/libs/
``` ```
@ -92,29 +95,14 @@ each project.
[the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite). [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite).
```bash ```bash
bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph bazel build -c opt mediapipe/graphs/face_detection:face_detection_mobile_gpu_binary_graph
cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/ cp bazel-bin/mediapipe/graphs/face_detection/face_detection_mobile_gpu.binarypb /path/to/your/app/src/main/assets/
cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/ cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/
``` ```
![Screenshot](../images/mobile/assets_location.png) ![Screenshot](../images/mobile/assets_location.png)
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into 4. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
app/src/main/jniLibs.
MediaPipe depends on OpenCV, you will need to copy the precompiled OpenCV so
files into app/src/main/jniLibs. You can download the official OpenCV
Android SDK from
[here](https://github.com/opencv/opencv/releases/download/3.4.3/opencv-3.4.3-android-sdk.zip)
and run:
```bash
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
```
![Screenshot](../images/mobile/android_studio_opencv_location.png)
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
``` ```
dependencies { dependencies {
@ -136,10 +124,14 @@ each project.
implementation "androidx.camera:camera-core:$camerax_version" implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version" implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version" implementation "androidx.camera:camera-lifecycle:$camerax_version"
// AutoValue
def auto_value_version = "1.6.4"
implementation "com.google.auto.value:auto-value-annotations:$auto_value_version"
annotationProcessor "com.google.auto.value:auto-value:$auto_value_version"
} }
``` ```
6. Follow our Android app examples to use MediaPipe in Android Studio for your 5. Follow our Android app examples to use MediaPipe in Android Studio for your
use case. If you are looking for an example, a face detection example can be use case. If you are looking for an example, a face detection example can be
found found
[here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and [here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and

View File

@ -471,7 +471,7 @@ next section.
4. Install Visual C++ Build Tools 2019 and WinSDK 4. Install Visual C++ Build Tools 2019 and WinSDK
Go to Go to
[the VisualStudio website](ttps://visualstudio.microsoft.com/visual-cpp-build-tools), [the VisualStudio website](https://visualstudio.microsoft.com/visual-cpp-build-tools),
download build tools, and install Microsoft Visual C++ 2019 Redistributable download build tools, and install Microsoft Visual C++ 2019 Redistributable
and Microsoft Build Tools 2019. and Microsoft Build Tools 2019.
@ -738,7 +738,7 @@ common build issues.
root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh
# Should print: # Should print:
# Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r18b # Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r19c
# Set android_ndk_repository and android_sdk_repository in WORKSPACE # Set android_ndk_repository and android_sdk_repository in WORKSPACE
# Done # Done
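To have Bazel pick up that NDK, the suggested environment variable can be exported as shown below; the path is copied from the message above and should be adjusted if your install differs:

```bash
export ANDROID_NDK_HOME=/root/Android/Sdk/ndk-bundle/android-ndk-r19c
```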

View File

@ -26,7 +26,7 @@ You can, for instance, activate a Python virtual environment:
$ python3 -m venv mp_env && source mp_env/bin/activate $ python3 -m venv mp_env && source mp_env/bin/activate
``` ```
Install MediaPipe Python package and start Python intepreter: Install MediaPipe Python package and start Python interpreter:
```bash ```bash
(mp_env)$ pip install mediapipe (mp_env)$ pip install mediapipe

View File

@ -97,6 +97,49 @@ linux_opencv/macos_opencv/windows_opencv.BUILD files for your local opencv
libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666) libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666)
may also help. may also help.
## Python pip install failure
The error message:
```
ERROR: Could not find a version that satisfies the requirement mediapipe
ERROR: No matching distribution found for mediapipe
```
after running `pip install mediapipe` usually indicates that there is no MediaPipe Python wheel that qualifies for your system.
Please note that MediaPipe Python PyPI officially supports the **64-bit**
version of Python 3.7 and above on the following OS:
- x86_64 Linux
- x86_64 macOS 10.15+
- amd64 Windows
If the OS is currently supported and you still see this error, please make sure
that both the Python and pip binaries are for Python 3.7 and above. Otherwise,
please consider building the MediaPipe Python package locally by following the
instructions [here](python.md#building-mediapipe-python-package).
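A quick way to confirm the interpreter meets the requirements above (Python 3.7+, 64-bit build) is sketched below; it relies only on the standard library:

```bash
python3 -c "import platform, struct; print(platform.python_version(), struct.calcsize('P') * 8, 'bit')"
python3 -m pip --version   # confirms which interpreter this pip is bound to
```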
## Python DLL load failure on Windows
The error message:
```
ImportError: DLL load failed: The specified module could not be found
```
usually indicates that the local Windows system is missing Visual C++
redistributable packages and/or Visual C++ runtime DLLs. This can be solved by
either installing the official
[vc_redist.x64.exe](https://support.microsoft.com/en-us/topic/the-latest-supported-visual-c-downloads-2647da03-1eea-4433-9aff-95f26a218cc0)
or installing the "msvc-runtime" Python package by running
```bash
$ python -m pip install msvc-runtime
```
Please note that the "msvc-runtime" Python package is not released or maintained
by Microsoft.
## Native method not found ## Native method not found
The error message: The error message:

Binary file not shown. (Before: 35 KiB, After: 34 KiB)

Binary file not shown. (Before: 75 KiB)

Binary file not shown. (Before: 29 KiB, After: 42 KiB)

Binary file not shown. (After: 2.3 MiB)

Binary file not shown. (Before: 6.9 MiB)

View File

@ -77,7 +77,7 @@ Supported configuration options:
```python ```python
import cv2 import cv2
import mediapipe as mp import mediapipe as mp
mp_face_detction = mp.solutions.face_detection mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils mp_drawing = mp.solutions.drawing_utils
# For static images: # For static images:

View File

@ -135,12 +135,11 @@ another detection until it loses track, on reducing computation and latency. If
set to `true`, person detection runs every input image, ideal for processing a set to `true`, person detection runs every input image, ideal for processing a
batch of static, possibly unrelated, images. Default to `false`. batch of static, possibly unrelated, images. Default to `false`.
#### upper_body_only #### model_complexity
If set to `true`, the solution outputs only the 25 upper-body pose landmarks Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
(535 in total) instead of the full set of 33 pose landmarks (543 in total). Note well as inference latency generally go up with the model complexity. Default to
that upper-body-only prediction may be more accurate for use cases where the `1`.
lower-body parts are mostly out of view. Default to `false`.
#### smooth_landmarks #### smooth_landmarks
@ -207,7 +206,7 @@ install MediaPipe Python package, then learn more in the companion
Supported configuration options: Supported configuration options:
* [static_image_mode](#static_image_mode) * [static_image_mode](#static_image_mode)
* [upper_body_only](#upper_body_only) * [model_complexity](#model_complexity)
* [smooth_landmarks](#smooth_landmarks) * [smooth_landmarks](#smooth_landmarks)
* [min_detection_confidence](#min_detection_confidence) * [min_detection_confidence](#min_detection_confidence)
* [min_tracking_confidence](#min_tracking_confidence) * [min_tracking_confidence](#min_tracking_confidence)
@ -219,7 +218,9 @@ mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic mp_holistic = mp.solutions.holistic
# For static images: # For static images:
with mp_holistic.Holistic(static_image_mode=True) as holistic: with mp_holistic.Holistic(
static_image_mode=True,
model_complexity=2) as holistic:
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
image = cv2.imread(file) image = cv2.imread(file)
image_height, image_width, _ = image.shape image_height, image_width, _ = image.shape
@ -240,8 +241,6 @@ with mp_holistic.Holistic(static_image_mode=True) as holistic:
annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
# Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when
# upper_body_only is set to True.
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
@ -291,7 +290,7 @@ and the following usage example.
Supported configuration options: Supported configuration options:
* [upperBodyOnly](#upper_body_only) * [modelComplexity](#model_complexity)
* [smoothLandmarks](#smooth_landmarks) * [smoothLandmarks](#smooth_landmarks)
* [minDetectionConfidence](#min_detection_confidence) * [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence) * [minTrackingConfidence](#min_tracking_confidence)
@ -348,7 +347,7 @@ const holistic = new Holistic({locateFile: (file) => {
return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`; return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`;
}}); }});
holistic.setOptions({ holistic.setOptions({
upperBodyOnly: false, modelComplexity: 1,
smoothLandmarks: true, smoothLandmarks: true,
minDetectionConfidence: 0.5, minDetectionConfidence: 0.5,
minTrackingConfidence: 0.5 minTrackingConfidence: 0.5

View File

@ -15,10 +15,10 @@ nav_order: 30
### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection) ### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection)
* Face detection model for front-facing/selfie camera: * Face detection model for front-facing/selfie camera:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite),
[TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite) [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite)
* Face detection model for back-facing camera: * Face detection model for back-facing camera:
[TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite) [TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_back.tflite)
* [Model card](https://mediapipe.page.link/blazeface-mc) * [Model card](https://mediapipe.page.link/blazeface-mc)
### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) ### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh)
@ -49,10 +49,10 @@ nav_order: 30
* Pose detection model: * Pose detection model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite) [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite)
* Full-body pose landmark model: * Pose landmark model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite) [TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite),
* Upper-body pose landmark model: [TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full.tflite),
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite) [TFLite model (heavy)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite)
* [Model card](https://mediapipe.page.link/blazepose-mc) * [Model card](https://mediapipe.page.link/blazepose-mc)
### [Holistic](https://google.github.io/mediapipe/solutions/holistic) ### [Holistic](https://google.github.io/mediapipe/solutions/holistic)

View File

@ -30,8 +30,7 @@ overlay of digital content and information on top of the physical world in
augmented reality. augmented reality.
MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring
33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video 33 3D landmarks on the whole body from RGB video frames utilizing our
frames utilizing our
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
research that also powers the research that also powers the
[ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection). [ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection).
@ -40,9 +39,9 @@ environments for inference, whereas our method achieves real-time performance on
most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in
[python](#python-solution-api) and even on the [web](#javascript-solution-api). [python](#python-solution-api) and even on the [web](#javascript-solution-api).
![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) | ![pose_tracking_example.gif](../images/mobile/pose_tracking_example.gif) |
:--------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------: |
*Fig 1. Example of MediaPipe Pose for upper-body pose tracking.* | *Fig 1. Example of MediaPipe Pose for pose tracking.* |
## ML Pipeline ## ML Pipeline
@ -77,6 +76,23 @@ Note: To visualize a graph, copy the graph and paste it into
to visualize its associated subgraphs, please see to visualize its associated subgraphs, please see
[visualizer documentation](../tools/visualizer.md). [visualizer documentation](../tools/visualizer.md).
## Pose Estimation Quality
To evaluate the quality of our [models](./models.md#pose) against other
well-performing publicly available solutions, we use a validation dataset,
consisting of 1k images with diverse Yoga, HIIT, and Dance postures. Each image
contains only a single person located 2-4 meters from the camera. To be
consistent with other solutions, we perform evaluation only for 17 keypoints
from [COCO topology](https://cocodataset.org/#keypoints-2020).
Method | [mAP](https://cocodataset.org/#keypoints-eval) | [PCK@0.2](https://github.com/cbsudux/Human-Pose-Estimation-101) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), MacBook Pro (15-inch, 2017)
----------------------------------------------------------------------------------------------------- | ---------------------------------------------: | --------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------: | ---------------------------------------------------------------------------:
BlazePose.Lite | 49.1 | 91.7 | 49 | 40
BlazePose.Full | 64.5 | 95.8 | 40 | 37
BlazePose.Heavy | 70.9 | 97.0 | 19 | 26
[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 57.6 | 93.1 | N/A | N/A
[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 37.0 | 85.3 | N/A | N/A
## Models ## Models
### Person/pose Detection Model (BlazePose Detector) ### Person/pose Detection Model (BlazePose Detector)
@ -97,11 +113,8 @@ hip midpoints.
### Pose Landmark Model (BlazePose GHUM 3D) ### Pose Landmark Model (BlazePose GHUM 3D)
The landmark model in MediaPipe Pose comes in two versions: a full-body model The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
that predicts the location of 33 pose landmarks (see figure below), and an (see figure below).
upper-body version that only predicts the first 25. The latter may be more
accurate than the former in scenarios where the lower-body parts are mostly out
of view.
Please find more detail in the Please find more detail in the
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), [BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
@ -129,12 +142,11 @@ until it loses track, on reducing computation and latency. If set to `true`,
person detection runs every input image, ideal for processing a batch of static, person detection runs every input image, ideal for processing a batch of static,
possibly unrelated, images. Default to `false`. possibly unrelated, images. Default to `false`.
#### upper_body_only #### model_complexity
If set to `true`, the solution outputs only the 25 upper-body pose landmarks. Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
Otherwise, it outputs the full set of 33 pose landmarks. Note that well as inference latency generally go up with the model complexity. Default to
upper-body-only prediction may be more accurate for use cases where the `1`.
lower-body parts are mostly out of view. Default to `false`.
#### smooth_landmarks #### smooth_landmarks
@ -170,9 +182,6 @@ A list of pose landmarks. Each lanmark consists of the following:
being the origin, and the smaller the value the closer the landmark is to being the origin, and the smaller the value the closer the landmark is to
the camera. The magnitude of `z` uses roughly the same scale as `x`. the camera. The magnitude of `z` uses roughly the same scale as `x`.
Note: `z` is predicted only in full-body mode, and should be discarded when
[upper_body_only](#upper_body_only) is `true`.
* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the * `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the
landmark being visible (present and not occluded) in the image. landmark being visible (present and not occluded) in the image.
@ -185,7 +194,7 @@ install MediaPipe Python package, then learn more in the companion
Supported configuration options: Supported configuration options:
* [static_image_mode](#static_image_mode) * [static_image_mode](#static_image_mode)
* [upper_body_only](#upper_body_only) * [model_complexity](#model_complexity)
* [smooth_landmarks](#smooth_landmarks) * [smooth_landmarks](#smooth_landmarks)
* [min_detection_confidence](#min_detection_confidence) * [min_detection_confidence](#min_detection_confidence)
* [min_tracking_confidence](#min_tracking_confidence) * [min_tracking_confidence](#min_tracking_confidence)
@ -198,7 +207,9 @@ mp_pose = mp.solutions.pose
# For static images: # For static images:
with mp_pose.Pose( with mp_pose.Pose(
static_image_mode=True, min_detection_confidence=0.5) as pose: static_image_mode=True,
model_complexity=2,
min_detection_confidence=0.5) as pose:
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
image = cv2.imread(file) image = cv2.imread(file)
image_height, image_width, _ = image.shape image_height, image_width, _ = image.shape
@ -214,8 +225,6 @@ with mp_pose.Pose(
) )
# Draw pose landmarks on the image. # Draw pose landmarks on the image.
annotated_image = image.copy() annotated_image = image.copy()
# Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when
# upper_body_only is set to True.
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
@ -259,7 +268,7 @@ and the following usage example.
Supported configuration options: Supported configuration options:
* [upperBodyOnly](#upper_body_only) * [modelComplexity](#model_complexity)
* [smoothLandmarks](#smooth_landmarks) * [smoothLandmarks](#smooth_landmarks)
* [minDetectionConfidence](#min_detection_confidence) * [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence) * [minTrackingConfidence](#min_tracking_confidence)
@ -306,7 +315,7 @@ const pose = new Pose({locateFile: (file) => {
return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`; return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`;
}}); }});
pose.setOptions({ pose.setOptions({
upperBodyOnly: false, modelComplexity: 1,
smoothLandmarks: true, smoothLandmarks: true,
minDetectionConfidence: 0.5, minDetectionConfidence: 0.5,
minTrackingConfidence: 0.5 minTrackingConfidence: 0.5
@ -347,16 +356,6 @@ to visualize its associated subgraphs, please see
* iOS target: * iOS target:
[`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD) [`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD)
#### Upper-body Only
* Graph:
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
* Android target:
[(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1uKc6T7KSuA0Mlq2URi5YookHu0U3yoh_/view?usp=sharing)
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu:upperbodyposetrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD)
* iOS target:
[`mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp`](http:/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD)
### Desktop ### Desktop
Please first see general instructions for [desktop](../getting_started/cpp.md) Please first see general instructions for [desktop](../getting_started/cpp.md)
@ -375,19 +374,6 @@ on how to build MediaPipe examples.
* Target: * Target:
[`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD) [`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
#### Upper-body Only
* Running on CPU
* Graph:
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt)
* Target:
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
* Running on GPU
* Graph:
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
* Target:
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
## Resources ## Resources
* Google AI Blog: * Google AI Blog:

View File

@ -16,7 +16,6 @@
"mediapipe/examples/ios/objectdetectiongpu/BUILD", "mediapipe/examples/ios/objectdetectiongpu/BUILD",
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD", "mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
"mediapipe/examples/ios/posetrackinggpu/BUILD", "mediapipe/examples/ios/posetrackinggpu/BUILD",
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD",
"mediapipe/framework/BUILD", "mediapipe/framework/BUILD",
"mediapipe/gpu/BUILD", "mediapipe/gpu/BUILD",
"mediapipe/objc/BUILD", "mediapipe/objc/BUILD",
@ -36,7 +35,6 @@
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp", "//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp", "//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp", "//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
"//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp",
"//mediapipe/objc:mediapipe_framework_ios" "//mediapipe/objc:mediapipe_framework_ios"
], ],
"optionSet" : { "optionSet" : {
@ -105,7 +103,6 @@
"mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectioncpu",
"mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/objectdetectiongpu",
"mediapipe/examples/ios/posetrackinggpu", "mediapipe/examples/ios/posetrackinggpu",
"mediapipe/examples/ios/upperbodyposetrackinggpu",
"mediapipe/framework", "mediapipe/framework",
"mediapipe/framework/deps", "mediapipe/framework/deps",
"mediapipe/framework/formats", "mediapipe/framework/formats",

View File

@ -22,7 +22,6 @@
"mediapipe/examples/ios/objectdetectiongpu", "mediapipe/examples/ios/objectdetectiongpu",
"mediapipe/examples/ios/objectdetectiontrackinggpu", "mediapipe/examples/ios/objectdetectiontrackinggpu",
"mediapipe/examples/ios/posetrackinggpu", "mediapipe/examples/ios/posetrackinggpu",
"mediapipe/examples/ios/upperbodyposetrackinggpu",
"mediapipe/objc" "mediapipe/objc"
], ],
"projectName" : "Mediapipe", "projectName" : "Mediapipe",

View File

@ -451,8 +451,8 @@ cc_library(
) )
cc_library( cc_library(
name = "nonzero_calculator", name = "non_zero_calculator",
srcs = ["nonzero_calculator.cc"], srcs = ["non_zero_calculator.cc"],
visibility = [ visibility = [
"//visibility:public", "//visibility:public",
], ],
@ -464,6 +464,21 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "non_zero_calculator_test",
size = "small",
srcs = ["non_zero_calculator_test.cc"],
deps = [
":non_zero_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
],
)
cc_test( cc_test(
name = "mux_calculator_test", name = "mux_calculator_test",
srcs = ["mux_calculator_test.cc"], srcs = ["mux_calculator_test.cc"],
@ -665,6 +680,18 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "default_side_packet_calculator",
srcs = ["default_side_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "side_packet_to_stream_calculator", name = "side_packet_to_stream_calculator",
srcs = ["side_packet_to_stream_calculator.cc"], srcs = ["side_packet_to_stream_calculator.cc"],

View File

@ -0,0 +1,103 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace {
constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE";
constexpr char kDefaultValueTag[] = "DEFAULT_VALUE";
constexpr char kValueTag[] = "VALUE";
} // namespace
// Outputs side packet default value if optional value is not provided.
//
// This calculator utilizes the fact that MediaPipe automatically removes
// optional side packets of the calculator configuration (i.e. OPTIONAL_VALUE).
// If that happens, the calculator returns the default value; otherwise it
// returns the optional value.
//
// Input:
// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE)
// Optional side packet value that is outputted by the calculator as is if
// provided.
//
// DEFAULT_VALUE - AnyType
// Default side packet value that is outputted by the calculator if
// OPTIONAL_VALUE is not provided.
//
// Output:
// VALUE - AnyType (but same type as DEFAULT_VALUE)
// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise).
//
// Usage example:
// node {
// calculator: "DefaultSidePacketCalculator"
// input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
// input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
// output_side_packet: "VALUE:segmentation_mask_enabled"
// }
class DefaultSidePacketCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(DefaultSidePacketCalculator);
absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
<< "Default value must be provided";
cc->InputSidePackets().Tag(kDefaultValueTag).SetAny();
// Optional input side packet can be unspecified. In this case MediaPipe will
// remove it from the calculator config.
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
cc->InputSidePackets()
.Tag(kOptionalValueTag)
.SetSameAs(&cc->InputSidePackets().Tag(kDefaultValueTag));
}
RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
cc->OutputSidePackets().Tag(kValueTag).SetSameAs(
&cc->InputSidePackets().Tag(kDefaultValueTag));
return absl::OkStatus();
}
absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
// If optional value is provided it is returned as the calculator output.
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
auto& packet = cc->InputSidePackets().Tag(kOptionalValueTag);
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
return absl::OkStatus();
}
// If no optional value is provided, output the default value.
auto& packet = cc->InputSidePackets().Tag(kDefaultValueTag);
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
return absl::OkStatus();
}
absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -23,14 +23,26 @@ namespace api2 {
class NonZeroCalculator : public Node { class NonZeroCalculator : public Node {
public: public:
static constexpr Input<int>::SideFallback kIn{"INPUT"}; static constexpr Input<int>::SideFallback kIn{"INPUT"};
static constexpr Output<int> kOut{"OUTPUT"}; static constexpr Output<int>::Optional kOut{"OUTPUT"};
static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut); MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);
absl::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
<< "At least one output stream is expected.";
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
if (!kIn(cc).IsEmpty()) { if (!kIn(cc).IsEmpty()) {
auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0); bool isNonZero = *kIn(cc) != 0;
kOut(cc).Send(std::move(output)); if (kOut(cc).IsConnected()) {
kOut(cc).Send(std::make_unique<int>(isNonZero ? 1 : 0));
}
if (kBooleanOut(cc).IsConnected()) {
kBooleanOut(cc).Send(std::make_unique<bool>(isNonZero));
}
} }
return absl::OkStatus(); return absl::OkStatus();
} }
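With the new optional outputs, a graph may connect either stream (the contract above requires at least one). A hypothetical node configuration, with stream names invented for illustration, that consumes only the boolean output:

```proto
node {
  calculator: "NonZeroCalculator"
  input_stream: "INPUT:detection_count"
  output_stream: "OUTPUT_BOOL:has_detections"
}
```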

View File

@ -0,0 +1,93 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
class NonZeroCalculatorTest : public ::testing::Test {
protected:
NonZeroCalculatorTest()
: runner_(
R"pb(
calculator: "NonZeroCalculator"
input_stream: "INPUT:input"
output_stream: "OUTPUT:output"
output_stream: "OUTPUT_BOOL:output_bool"
)pb") {}
void SetInput(const std::vector<int>& inputs) {
int timestamp = 0;
for (const auto input : inputs) {
runner_.MutableInputs()
->Get("INPUT", 0)
.packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
}
}
std::vector<int> GetOutput() {
std::vector<int> result;
for (const auto output : runner_.Outputs().Get("OUTPUT", 0).packets) {
result.push_back(output.Get<int>());
}
return result;
}
std::vector<bool> GetOutputBool() {
std::vector<bool> result;
for (const auto output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
result.push_back(output.Get<bool>());
}
return result;
}
CalculatorRunner runner_;
};
TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
SetInput({0});
MP_ASSERT_OK(runner_.Run());
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
}
TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
SetInput({1, 2, 3, -4, 5});
MP_ASSERT_OK(runner_.Run());
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
EXPECT_THAT(GetOutputBool(),
::testing::ElementsAre(true, true, true, true, true));
}
TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
SetInput({1, 0, 3, 0, 5});
MP_ASSERT_OK(runner_.Run());
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
EXPECT_THAT(GetOutputBool(),
::testing::ElementsAre(true, false, true, false, true));
}
} // namespace mediapipe

View File

@ -285,7 +285,7 @@ absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
// Run cropping shader on GPU. // Run cropping shader on GPU.
{ {
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 gpu_helper_.BindFramebuffer(dst_tex);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(src_tex.target(), src_tex.name()); glBindTexture(src_tex.target(), src_tex.name());

View File

@ -546,7 +546,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height, auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height,
input.format()); input.format());
gpu_helper_.BindFramebuffer(dst); // GL_TEXTURE0 gpu_helper_.BindFramebuffer(dst);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), src1.name()); glBindTexture(src1.target(), src1.name());

View File

@ -209,6 +209,9 @@ absl::Status RecolorCalculator::Close(CalculatorContext* cc) {
absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) { if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) {
cc->Outputs()
.Tag(kImageFrameTag)
.AddPacket(cc->Inputs().Tag(kImageFrameTag).Value());
return absl::OkStatus(); return absl::OkStatus();
} }
// Get inputs and setup output. // Get inputs and setup output.
@ -270,6 +273,9 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) { if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) {
cc->Outputs()
.Tag(kGpuBufferTag)
.AddPacket(cc->Inputs().Tag(kGpuBufferTag).Value());
return absl::OkStatus(); return absl::OkStatus();
} }
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -287,7 +293,7 @@ absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
// Run recolor shader on GPU. // Run recolor shader on GPU.
{ {
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 gpu_helper_.BindFramebuffer(dst_tex);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(img_tex.target(), img_tex.name()); glBindTexture(img_tex.target(), img_tex.name());

View File

@ -323,7 +323,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
const auto& alpha_mask = const auto& alpha_mask =
cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>();
auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask); auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask);
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 gpu_helper_.BindFramebuffer(output_texture);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, input_texture.name()); glBindTexture(GL_TEXTURE_2D, input_texture.name());
glActiveTexture(GL_TEXTURE2); glActiveTexture(GL_TEXTURE2);
@ -335,7 +335,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
alpha_texture.Release(); alpha_texture.Release();
} else { } else {
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 gpu_helper_.BindFramebuffer(output_texture);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, input_texture.name()); glBindTexture(GL_TEXTURE_2D, input_texture.name());
GlRender(cc); // use value from options GlRender(cc); // use value from options

View File

@ -490,6 +490,7 @@ cc_library(
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:port", "//mediapipe/framework:port",
"//mediapipe/gpu:gpu_origin_cc_proto",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [":image_to_tensor_calculator_gpu_deps"], "//conditions:default": [":image_to_tensor_calculator_gpu_deps"],
@ -526,6 +527,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:gpu_origin_proto",
], ],
) )

View File

@ -31,6 +31,7 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
@ -236,7 +237,7 @@ class ImageToTensorCalculator : public Node {
} }
private: private:
bool DoesInputStartAtBottom() { bool DoesGpuInputStartAtBottom() {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
} }
@ -290,11 +291,11 @@ class ImageToTensorCalculator : public Node {
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
ASSIGN_OR_RETURN(gpu_converter_, ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlBufferTensorConverter( CreateImageToGlBufferTensorConverter(
cc, DoesInputStartAtBottom(), GetBorderMode())); cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
#else #else
ASSIGN_OR_RETURN(gpu_converter_, ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlTextureTensorConverter( CreateImageToGlTextureTensorConverter(
cc, DoesInputStartAtBottom(), GetBorderMode())); cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} }
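For reference, a hypothetical options snippet (tensor sizes invented for illustration) showing how a graph selects the top-left GPU origin that the renamed `DoesGpuInputStartAtBottom()` checks against:

```proto
node {
  calculator: "ImageToTensorCalculator"
  input_stream: "IMAGE_GPU:input_image"
  output_stream: "TENSORS:input_tensors"
  options {
    [mediapipe.ImageToTensorCalculatorOptions.ext] {
      output_tensor_width: 224
      output_tensor_height: 224
      gpu_origin: TOP_LEFT
    }
  }
}
```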

View File

@ -17,20 +17,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/gpu_origin.proto";
message GpuOrigin {
enum Mode {
DEFAULT = 0;
// OpenGL: bottom-left origin
// Metal : top-left origin
CONVENTIONAL = 1;
// OpenGL: top-left origin
// Metal : top-left origin
TOP_LEFT = 2;
}
}
message ImageToTensorCalculatorOptions { message ImageToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {

View File

@ -317,7 +317,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
// Configure and create the delegate. // Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1; options.compile_options.precision_loss_allowed =
allow_precision_loss_ ? 1 : 0;
options.compile_options.preferred_gl_object_type = options.compile_options.preferred_gl_object_type =
TFLITE_GL_OBJECT_TYPE_FASTEST; TFLITE_GL_OBJECT_TYPE_FASTEST;
options.compile_options.dynamic_batch_enabled = 0; options.compile_options.dynamic_batch_enabled = 0;

View File

@@ -97,6 +97,7 @@ class InferenceCalculatorMetalImpl
   Packet<TfLiteModelPtr> model_packet_;
   std::unique_ptr<tflite::Interpreter> interpreter_;
   TfLiteDelegatePtr delegate_;
+  bool allow_precision_loss_ = false;

 #if MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;

@@ -122,6 +123,9 @@ absl::Status InferenceCalculatorMetalImpl::UpdateContract(
 }

 absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
+  const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
+  allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
+
   MP_RETURN_IF_ERROR(LoadModel(cc));

   gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];

@@ -222,7 +226,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
   // Configure and create the delegate.
   TFLGpuDelegateOptions options;
-  options.allow_precision_loss = true;
+  options.allow_precision_loss = allow_precision_loss_;
   options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
   delegate_ =
       TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);

@@ -239,7 +243,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
                          tensor->dims->data + tensor->dims->size};
     dims.back() = RoundUp(dims.back(), 4);
     gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>(
-        Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
+        allow_precision_loss_ ? Tensor::ElementType::kFloat16
+                              : Tensor::ElementType::kFloat32,
+        Tensor::Shape{dims}));
     auto buffer_view =
         gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
     RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(

@@ -261,7 +267,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
     output_shapes_[i] = {dims};
     dims.back() = RoundUp(dims.back(), 4);
     gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>(
-        Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
+        allow_precision_loss_ ? Tensor::ElementType::kFloat16
+                              : Tensor::ElementType::kFloat32,
+        Tensor::Shape{dims}));
     RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
                      delegate_.get(), output_indices[i],
                      gpu_buffers_out_[i]

@@ -271,17 +279,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
   }

   // Create converter for GPU input.
-  converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
-                                                        isFloat16:true
-                                                  convertToPBHWC4:true];
+  converter_to_BPHWC4_ =
+      [[TFLBufferConvert alloc] initWithDevice:device
+                                     isFloat16:allow_precision_loss_
+                               convertToPBHWC4:true];
   if (converter_to_BPHWC4_ == nil) {
     return mediapipe::InternalError(
         "Error initializating input buffer converter");
   }
   // Create converter for GPU output.
-  converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
-                                                          isFloat16:true
-                                                    convertToPBHWC4:false];
+  converter_from_BPHWC4_ =
+      [[TFLBufferConvert alloc] initWithDevice:device
+                                     isFloat16:allow_precision_loss_
+                               convertToPBHWC4:false];
   if (converter_from_BPHWC4_ == nil) {
     return absl::InternalError("Error initializating output buffer converter");
   }
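
With the hunks above, fp16 versus fp32 behavior of the GL and Metal delegates follows the graph options read in Open() instead of being hard-coded. As a rough, hedged sketch (not part of this change), a node could be configured to keep full-precision buffers roughly as follows; the model path and stream names are placeholders:

#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: build an InferenceCalculator node that keeps fp32 GPU buffers.
// Field names follow the InferenceCalculatorOptions used above; the model
// path and stream names are placeholders, not taken from this change.
void ConfigureFp32Inference() {
  mediapipe::CalculatorRunner runner(
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
          R"pb(
            calculator: "InferenceCalculator"
            input_stream: "TENSORS:input_tensors"
            output_stream: "TENSORS:output_tensors"
            options {
              [mediapipe.InferenceCalculatorOptions.ext] {
                model_path: "path/to/model.tflite"  # placeholder
                delegate { gpu { allow_precision_loss: false } }
              }
            }
          )pb"));
  // The runner would then be fed input packets and Run() as in the tests
  // elsewhere in this change.
}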


@@ -89,7 +89,8 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
     ASSIGN_OR_RETURN(string_path,
                      PathToResourceAsFile(options_.label_map_path()));
     std::string label_map_string;
-    MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
+    MP_RETURN_IF_ERROR(
+        mediapipe::GetResourceContents(string_path, &label_map_string));

     std::istringstream stream(label_map_string);
     std::string line;

@@ -98,6 +99,14 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
       label_map_[i++] = line;
     }
     label_map_loaded_ = true;
+  } else if (options_.has_label_map()) {
+    for (int i = 0; i < options_.label_map().entries_size(); ++i) {
+      const auto& entry = options_.label_map().entries(i);
+      RET_CHECK(!label_map_.contains(entry.id()))
+          << "Duplicate id found: " << entry.id();
+      label_map_[entry.id()] = entry.label();
+    }
+    label_map_loaded_ = true;
   }

   return absl::OkStatus();


@@ -25,6 +25,14 @@ message TensorsToClassificationCalculatorOptions {
     optional TensorsToClassificationCalculatorOptions ext = 335742638;
   }

+  message LabelMap {
+    message Entry {
+      optional int32 id = 1;
+      optional string label = 2;
+    }
+    repeated Entry entries = 1;
+  }
+
   // Score threshold for perserving the class.
   optional float min_score_threshold = 1;
   // Number of highest scoring labels to output. If top_k is not positive then

@@ -32,6 +40,10 @@ message TensorsToClassificationCalculatorOptions {
   optional int32 top_k = 2;
   // Path to a label map file for getting the actual name of class ids.
   optional string label_map_path = 3;
+  // Label map. (Can be used instead of label_map_path.)
+  // NOTE: "label_map_path", if specified, takes precedence over "label_map".
+  optional LabelMap label_map = 5;
+
   // Whether the input is a single float for binary classification.
   // When true, only a single float is expected in the input tensor and the
   // label map, if provided, is expected to have exactly two labels.


@@ -115,6 +115,41 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMapPath) {
   }
 }

+TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMap) {
+  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "TensorsToClassificationCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "CLASSIFICATIONS:classifications"
+    options {
+      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
+        label_map {
+          entries { id: 0, label: "ClassA" }
+          entries { id: 1, label: "ClassB" }
+          entries { id: 2, label: "ClassC" }
+        }
+      }
+    }
+  )pb"));
+  BuildGraph(&runner, {0, 0.5, 1});
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
+  EXPECT_EQ(1, output_packets_.size());
+
+  const auto& classification_list =
+      output_packets_[0].Get<ClassificationList>();
+  EXPECT_EQ(3, classification_list.classification_size());
+
+  // Verify that the label field is set.
+  for (int i = 0; i < classification_list.classification_size(); ++i) {
+    EXPECT_EQ(i, classification_list.classification(i).index());
+    EXPECT_EQ(i * 0.5, classification_list.classification(i).score());
+    ASSERT_TRUE(classification_list.classification(i).has_label());
+  }
+}
+
 TEST_F(TensorsToClassificationCalculatorTest,
        CorrectOutputWithLabelMinScoreThreshold) {
   mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(


@@ -34,15 +34,28 @@ constexpr char kTensor[] = "TENSOR";
 }  // namespace

 // Input:
-//  Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The
-//  shape can be HxWx{3,1} or simply HxW.
+//  Tensor of type DT_FLOAT or DT_UINT8, with values between 0-255
+//  (SRGB or GRAY8). The shape can be HxWx{3,1} or simply HxW.
 //
-// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255.
+// For DT_FLOAT tensors, optionally supports a scale factor that can scale 0-1
+// value ranges to 0-255.
 //
 // Output:
 //  ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
 //
 // Possible extensions: support other input ranges, maybe 4D tensors.
+//
+// Example:
+// node {
+//   calculator: "TensorToImageFrameCalculator"
+//   input_stream: "TENSOR:3d_float_tensor"
+//   output_stream: "IMAGE:image_frame"
+//   options {
+//     [mediapipe.TensorToImageFrameCalculatorOptions.ext] {
+//       scale_factor: 1.0  # set to 255.0 for [0,1] -> [0,255] scaling
+//     }
+//   }
+// }
 class TensorToImageFrameCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc);

@@ -57,8 +70,8 @@ class TensorToImageFrameCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(TensorToImageFrameCalculator);

 absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
-  RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
-      << "Only one input stream is supported.";
+  RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
+      << "Only one output stream is supported.";
   RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
       << "One input stream must be provided.";
   RET_CHECK(cc->Inputs().HasTag(kTensor))

@@ -91,29 +104,44 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
       RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1.";
     }
   }
-  const int32 total_size =
-      input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth;
-  std::unique_ptr<uint8[]> buffer(new uint8[total_size]);
-  auto data = input_tensor.flat<float>().data();
-  for (int i = 0; i < total_size; ++i) {
-    float d = scale_factor_ * data[i];
-    if (d < 0) d = 0;
-    if (d > 255) d = 255;
-    buffer[i] = d;
-  }
-  ::std::unique_ptr<ImageFrame> output;
-  if (depth == 3) {
-    output = ::absl::make_unique<ImageFrame>(
-        ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0),
-        input_tensor.dim_size(1) * 3, buffer.release());
-  } else if (depth == 1) {
-    output = ::absl::make_unique<ImageFrame>(
-        ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0),
-        input_tensor.dim_size(1), buffer.release());
-  } else {
-    return absl::InvalidArgumentError("Unrecognized image depth.");
-  }
+  int32 height = input_tensor.dim_size(0);
+  int32 width = input_tensor.dim_size(1);
+  auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
+  const int32 total_size = height * width * depth;
+
+  ::std::unique_ptr<const ImageFrame> output;
+  if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
+    // Allocate buffer with alignments.
+    std::unique_ptr<uint8_t[]> buffer(
+        new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
+    auto data = input_tensor.flat<float>().data();
+    for (int i = 0; i < total_size; ++i) {
+      float d = scale_factor_ * data[i];
+      if (d < 0) d = 0;
+      if (d > 255) d = 255;
+      buffer[i] = d;
+    }
+    output = ::absl::make_unique<ImageFrame>(format, width, height,
+                                             width * depth, buffer.release());
+  } else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
+    if (scale_factor_ != 1.0) {
+      return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");
+    }
+    // tf::Tensor has internally ref-counted buffer. The following code make the
+    // ImageFrame own the copied Tensor through the deleter, which increases
+    // the refcount of the buffer and allow us to use the shared buffer as the
+    // image. This allows us to create an ImageFrame object without copying
+    // buffer. const ImageFrame prevents the buffer from being modified later.
+    auto copy = new tf::Tensor(input_tensor);
+    output = ::absl::make_unique<const ImageFrame>(
+        format, width, height, width * depth, copy->flat<uint8_t>().data(),
+        [copy](uint8*) { delete copy; });
+  } else {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Expected float or uint8 tensor, received ",
+                     DataTypeString(input_tensor.dtype())));
+  }
+
   cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp());
   return absl::OkStatus();
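
The DT_UINT8 branch above relies on the ImageFrame constructor that takes an externally owned pixel buffer plus a deleter, so no pixel copy is made. A minimal sketch of the same pattern with a plain heap buffer (everything below is illustrative and not part of this change):

#include <cstdint>
#include <memory>

#include "absl/memory/memory.h"
#include "mediapipe/framework/formats/image_frame.h"

// Sketch: wrap an existing width*height*3 RGB buffer in an ImageFrame without
// copying. The deleter runs when the ImageFrame is destroyed; here it simply
// frees the array it was handed.
std::unique_ptr<const mediapipe::ImageFrame> WrapRgbBuffer(uint8_t* pixels,
                                                           int width,
                                                           int height) {
  return absl::make_unique<const mediapipe::ImageFrame>(
      mediapipe::ImageFormat::SRGB, width, height,
      /*width_step=*/width * 3, pixels,
      [](uint8_t* p) { delete[] p; });
}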


@@ -29,6 +29,7 @@ constexpr char kImage[] = "IMAGE";
 }  // namespace

+template <class TypeParam>
 class TensorToImageFrameCalculatorTest : public ::testing::Test {
  protected:
   void SetUpRunner() {

@@ -42,14 +43,20 @@ class TensorToImageFrameCalculatorTest : public ::testing::Test {
   std::unique_ptr<CalculatorRunner> runner_;
 };

-TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
-  SetUpRunner();
+using TensorToImageFrameCalculatorTestTypes = ::testing::Types<float, uint8_t>;
+TYPED_TEST_CASE(TensorToImageFrameCalculatorTest,
+                TensorToImageFrameCalculatorTestTypes);
+
+TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
+  // TYPED_TEST requires explicit "this->"
+  this->SetUpRunner();
+  auto& runner = this->runner_;
   constexpr int kWidth = 16;
   constexpr int kHeight = 8;
-  const tf::TensorShape tensor_shape(
-      std::vector<tf::int64>{kHeight, kWidth, 3});
-  auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
-  auto tensor_vec = tensor->flat<float>().data();
+  const tf::TensorShape tensor_shape{kHeight, kWidth, 3};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();

   // Writing sequence of integers as floats which we want back (as they were
   // written).

@@ -58,15 +65,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
   }

   const int64 time = 1234;
-  runner_->MutableInputs()->Tag(kTensor).packets.push_back(
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
       Adopt(tensor.release()).At(Timestamp(time)));
-  EXPECT_TRUE(runner_->Run().ok());
+  EXPECT_TRUE(runner->Run().ok());
   const std::vector<Packet>& output_packets =
-      runner_->Outputs().Tag(kImage).packets;
+      runner->Outputs().Tag(kImage).packets;
   EXPECT_EQ(1, output_packets.size());
   EXPECT_EQ(time, output_packets[0].Timestamp().Value());
   const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+  EXPECT_EQ(ImageFormat::SRGB, output_image.Format());
   EXPECT_EQ(kWidth, output_image.Width());
   EXPECT_EQ(kHeight, output_image.Height());

@@ -76,14 +84,15 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
   }
 }

-TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
-  SetUpRunner();
+TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
+  this->SetUpRunner();
+  auto& runner = this->runner_;
   constexpr int kWidth = 16;
   constexpr int kHeight = 8;
-  const tf::TensorShape tensor_shape(
-      std::vector<tf::int64>{kHeight, kWidth, 1});
-  auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
-  auto tensor_vec = tensor->flat<float>().data();
+  const tf::TensorShape tensor_shape{kHeight, kWidth, 1};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();

   // Writing sequence of integers as floats which we want back (as they were
   // written).

@@ -92,15 +101,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
   }

   const int64 time = 1234;
-  runner_->MutableInputs()->Tag(kTensor).packets.push_back(
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
       Adopt(tensor.release()).At(Timestamp(time)));
-  EXPECT_TRUE(runner_->Run().ok());
+  EXPECT_TRUE(runner->Run().ok());
   const std::vector<Packet>& output_packets =
-      runner_->Outputs().Tag(kImage).packets;
+      runner->Outputs().Tag(kImage).packets;
   EXPECT_EQ(1, output_packets.size());
   EXPECT_EQ(time, output_packets[0].Timestamp().Value());
   const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+  EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
   EXPECT_EQ(kWidth, output_image.Width());
   EXPECT_EQ(kHeight, output_image.Height());

@@ -110,13 +120,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
   }
 }

-TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
-  SetUpRunner();
+TYPED_TEST(TensorToImageFrameCalculatorTest,
+           Converts3DTensorToImageFrame2DGray) {
+  this->SetUpRunner();
+  auto& runner = this->runner_;
   constexpr int kWidth = 16;
   constexpr int kHeight = 8;
-  const tf::TensorShape tensor_shape(std::vector<tf::int64>{kHeight, kWidth});
-  auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
-  auto tensor_vec = tensor->flat<float>().data();
+  const tf::TensorShape tensor_shape{kHeight, kWidth};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();

   // Writing sequence of integers as floats which we want back (as they were
   // written).

@@ -125,15 +138,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
   }

   const int64 time = 1234;
-  runner_->MutableInputs()->Tag(kTensor).packets.push_back(
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
       Adopt(tensor.release()).At(Timestamp(time)));
-  EXPECT_TRUE(runner_->Run().ok());
+  EXPECT_TRUE(runner->Run().ok());
   const std::vector<Packet>& output_packets =
-      runner_->Outputs().Tag(kImage).packets;
+      runner->Outputs().Tag(kImage).packets;
   EXPECT_EQ(1, output_packets.size());
   EXPECT_EQ(time, output_packets[0].Timestamp().Value());
   const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+  EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
   EXPECT_EQ(kWidth, output_image.Width());
   EXPECT_EQ(kHeight, output_image.Height());
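
For readers less familiar with googletest's typed tests, the pattern introduced above (TYPED_TEST_CASE plus explicit this->) reduces to a small standalone sketch; the fixture and type list below are made up purely for illustration:

#include <cstdint>
#include <vector>

#include "gtest/gtest.h"

// Illustrative fixture only. Because the fixture is a class template, its
// members live in a dependent base, so TYPED_TEST bodies must reach them
// through "this->", exactly as the calculator test above does.
template <typename T>
class SmallTypedTest : public ::testing::Test {
 protected:
  std::vector<T> values_;
};

using SmallTypedTestTypes = ::testing::Types<float, uint8_t>;
TYPED_TEST_CASE(SmallTypedTest, SmallTypedTestTypes);

TYPED_TEST(SmallTypedTest, StartsEmptyThenGrows) {
  EXPECT_TRUE(this->values_.empty());
  this->values_.push_back(TypeParam{0});
  EXPECT_EQ(1u, this->values_.size());
}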


@@ -91,8 +91,6 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
 // the input data when it arrives in Process(). In particular, if the header
 // states that we produce a 1xD column vector, the input tensor must also be 1xD
 //
-// This designed was discussed in http://g/speakeranalysis/4uyx7cNRwJY and
-// http://g/daredevil-project/VB26tcseUy8.
 // Example Config
 // node: {
 //   calculator: "TensorToMatrixCalculator"

@@ -158,22 +156,17 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
   if (header_status.ok()) {
     if (cc->Options<TensorToMatrixCalculatorOptions>()
             .has_time_series_header_overrides()) {
-      // From design discussions with Daredevil, we only want to support single
-      // sample per packet for now, so we hardcode the sample_rate based on the
-      // packet_rate of the REFERENCE and fail noisily if we cannot. An
-      // alternative would be to calculate the sample_rate from the reference
-      // sample_rate and the change in num_samples between the reference and
-      // override headers:
-      //   sample_rate_output = sample_rate_reference /
-      //       (num_samples_override / num_samples_reference)
+      // This only supports a single sample per packet for now, so we hardcode
+      // the sample_rate based on the packet_rate of the REFERENCE and fail
+      // if we cannot.
       const TimeSeriesHeader& override_header =
           cc->Options<TensorToMatrixCalculatorOptions>()
               .time_series_header_overrides();
       input_header->MergeFrom(override_header);
-      CHECK(input_header->has_packet_rate())
+      RET_CHECK(input_header->has_packet_rate())
           << "The TimeSeriesHeader.packet_rate must be set.";
       if (!override_header.has_sample_rate()) {
-        CHECK_EQ(input_header->num_samples(), 1)
+        RET_CHECK_EQ(input_header->num_samples(), 1)
             << "Currently the time series can only output single samples.";
         input_header->set_sample_rate(input_header->packet_rate());
       }

@@ -186,20 +179,16 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
 }

 absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
-  // Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK
-  // failures. These are absolute conditions of the graph for the graph to be
-  // valid, and if it is violated by any input anywhere, the graph will be
-  // invalid for all inputs. A hard CHECK will enable faster debugging by
-  // immediately exiting and more prominently displaying error messages.
-  // Do not replace with RET_CHECKs.
   // Verify that each reference stream packet corresponds to a tensor packet
   // otherwise the header information is invalid. If we don't have a reference
   // stream, Process() is only called when we have an input tensor and this is
   // always True.
-  CHECK(cc->Inputs().HasTag(kTensor))
+  RET_CHECK(cc->Inputs().HasTag(kTensor))
       << "Tensor stream not available at same timestamp as the reference "
          "stream.";
+  RET_CHECK(!cc->Inputs().Tag(kTensor).IsEmpty()) << "Tensor stream is empty.";
+  RET_CHECK_OK(cc->Inputs().Tag(kTensor).Value().ValidateAsType<tf::Tensor>())
+      << "Tensor stream packet does not contain a Tensor.";
   const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
   CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())

@@ -207,13 +196,12 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
   const int32 length = input_tensor.dim_size(input_tensor.dims() - 1);
   const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0);
   if (header_.has_num_channels()) {
-    CHECK_EQ(length, header_.num_channels())
+    RET_CHECK_EQ(length, header_.num_channels())
         << "The number of channels at runtime does not match the header.";
   }
   if (header_.has_num_samples()) {
-    CHECK_EQ(width, header_.num_samples())
+    RET_CHECK_EQ(width, header_.num_samples())
         << "The number of samples at runtime does not match the header.";
-    ;
   }
   auto output = absl::make_unique<Matrix>(width, length);
   *output =
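
For reference, the overridden header handled above is supplied through the calculator options; a hedged sketch of such a node follows. The stream tags and the num_channels value are illustrative assumptions, not taken from this change, and with the new RET_CHECKs a missing packet_rate now surfaces as a graph error instead of crashing the process.

#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: a TensorToMatrixCalculator node overriding part of the output
// TimeSeriesHeader. Tags and values are placeholders for illustration.
auto node_config =
    mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
      calculator: "TensorToMatrixCalculator"
      input_stream: "REFERENCE:reference_stream"  # assumed tag
      input_stream: "TENSOR:input_tensor"
      output_stream: "MATRIX:output_matrix"  # assumed tag
      options {
        [mediapipe.TensorToMatrixCalculatorOptions.ext] {
          time_series_header_overrides { num_channels: 64 }  # illustrative
        }
      }
    )pb");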


@ -98,388 +98,543 @@ class InferenceState {
// This calculator performs inference on a trained TensorFlow model. // This calculator performs inference on a trained TensorFlow model.
// //
// A mediapipe::TensorFlowSession with a model loaded and ready for use. // TensorFlow Sessions can be created from checkpoint paths, frozen models, or
// For this calculator it must include a tag_to_tensor_map. // the SavedModel system. See the TensorFlowSessionFrom* packet generators for
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>(); // details. Each of these methods defines a mapping between MediaPipe streams
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { // and TensorFlow tensors. All of this information is passed in as an
cc->InputSidePackets() // input_side_packet.
.Tag("RECURRENT_INIT_TENSORS") //
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>(); // The input and output streams are TensorFlow tensors labeled by tags. The tags
} // for the streams are matched to feeds and fetchs in a TensorFlow session using
return absl::OkStatus(); // a named_signature.generic_signature in the ModelManifest. The
} // generic_signature is used as key-value pairs between the MediaPipe tag and
// the TensorFlow tensor. The signature_name in the options proto determines
// which named_signature is used. The keys in the generic_signature must be
// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
// the tensors corresponding to tags in the signature for input_streams are fed
// to the model and for output_streams the tensors are fetched from the model.
//
// Other calculators are used to convert data to and from tensors, this op only
// handles the TensorFlow session and batching. Batching occurs by concatenating
// input tensors along the 0th dimension across timestamps. If the 0th dimension
// is not a batch dimension, this calculator will add a 0th dimension by
// default. Setting add_batch_dim_to_tensors to false disables the dimension
// addition. Once batch_size inputs have been provided, the batch will be run
// and the output tensors sent out on the output streams with timestamps
// corresponding to the input stream packets. Setting the batch_size to 1
// completely disables batching, but is indepdent of add_batch_dim_to_tensors.
//
// The TensorFlowInferenceCalculator also support feeding states recurrently for
// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
// recurrent tensors. Initializing the recurrent state can be handled by the
// GraphTensorsPacketGenerator.
//
// The calculator updates two Counters to report timing information:
// --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
// --<name>-TotalProcessedTimestamps = # of instances processed
// (approximately batches processed * batch_size),
// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
// exists, or with TensorFlowInferenceCalculator if the name is not set. The
// name must be set for timing information to be instance-specific in graphs
// with multiple TensorFlowInferenceCalculators.
//
// Example config:
// packet_generator {
// packet_generator: "TensorFlowSessionFromSavedModelGenerator"
// output_side_packet: "tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
// saved_model_path: "/path/to/saved/model"
// signature_name: "mediapipe"
// }
// }
// }
// node {
// calculator: "TensorFlowInferenceCalculator"
// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
// input_side_packet: "SESSION:tensorflow_session"
// }
//
// Where the input and output streams are treated as Packet<tf::Tensor> and
// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
// "LABELS" and their respective tensors exported to /path/to/bundle. For an
// example of how this model was exported, see
// tensorflow_inference_test_graph_generator.py
//
// It is possible to use a GraphDef proto that was not exported by exporter (i.e
// without MetaGraph with bindings). Such GraphDef could contain all of its
// parameters in-lined (for example, it can be the output of freeze_graph.py).
// To instantiate a TensorFlow model from a GraphDef file, replace the
// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator:
//
// packet_generator {
// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
// output_side_packet: "SESSION:tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
// graph_proto_path: "[PATH]"
// tag_to_tensor_names {
// key: "JPG_STRING"
// value: "input:0"
// }
// tag_to_tensor_names {
// key: "SOFTMAX"
// value: "softmax:0"
// }
// }
// }
// }
//
// It is also possible to use a GraphDef proto and checkpoint file that have not
// been frozen. This can be used to load graphs directly as they have been
// written from training. However, it is more brittle and you are encouraged to
// use a one of the more perminent formats described above. To instantiate a
// TensorFlow model from a GraphDef file and checkpoint, replace the
// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator:
//
// packet_generator {
// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator"
// output_side_packet: "SESSION:tensorflow_session"
// options {
// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: {
// graph_proto_path: "[PATH]"
// model_options {
// checkpoint_path: "[PATH2]"
// }
// tag_to_tensor_names {
// key: "JPG_STRING"
// value: "input:0"
// }
// tag_to_tensor_names {
// key: "SOFTMAX"
// value: "softmax:0"
// }
// }
// }
// }
class TensorFlowInferenceCalculator : public CalculatorBase {
public:
// Counters for recording timing information. The actual names have the value
// of CalculatorGraphConfig::Node::name() prepended.
static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs";
static constexpr char kTotalProcessedTimestampsCounterSuffix[] =
"TotalProcessedTimestamps";
static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] =
"TotalSessionRunsTimeUsecs";
static constexpr char kTotalNumSessionRunsCounterSuffix[] =
"TotalNumSessionRuns";
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc) TensorFlowInferenceCalculator() : session_(nullptr) {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { clock_ = std::unique_ptr<mediapipe::Clock>(
std::unique_ptr<InferenceState> inference_state = mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
absl::make_unique<InferenceState>(); }
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { static absl::Status GetContract(CalculatorContract* cc) {
std::map<std::string, tf::Tensor>* init_tensor_map; const auto& options = cc->Options<TensorFlowInferenceCalculatorOptions>();
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>( RET_CHECK(!cc->Inputs().GetTags().empty());
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); for (const std::string& tag : cc->Inputs().GetTags()) {
for (const auto& p : *init_tensor_map) { // The tensorflow::Tensor with the tag equal to the graph node. May
inference_state->input_tensor_batches_[p.first].emplace_back(p.second); // have a TimeSeriesHeader if all present TimeSeriesHeaders match.
if (!options.batched_input()) {
cc->Inputs().Tag(tag).Set<tf::Tensor>();
} else {
cc->Inputs().Tag(tag).Set<std::vector<mediapipe::Packet>>();
}
} }
} RET_CHECK(!cc->Outputs().GetTags().empty());
return inference_state; for (const std::string& tag : cc->Outputs().GetTags()) {
} // The tensorflow::Tensor with tag equal to the graph node to
// output. Any TimeSeriesHeader from the inputs will be forwarded
absl::Status Open(CalculatorContext* cc) override { // with channels set to 0.
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>(); cc->Outputs().Tag(tag).Set<tf::Tensor>();
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
session_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.session.get();
tag_to_tensor_map_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.tag_to_tensor_map;
// Validate and store the recurrent tags
RET_CHECK(options_.has_batch_size());
RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty())
<< "To use recurrent_tag_pairs, batch_size must be 1.";
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
RET_CHECK_EQ(tags.size(), 2)
<< "recurrent_tag_pair must be a colon "
"separated std::string with two components: "
<< tag_pair;
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
<< "Can't find tag '" << tags[0] << "' in signature "
<< options_.signature_name();
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
<< "Can't find tag '" << tags[1] << "' in signature "
<< options_.signature_name();
recurrent_feed_tags_.insert(tags[0]);
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
}
// Check that all tags are present in this signature bound to tensors.
for (const std::string& tag : cc->Inputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
for (const std::string& tag : cc->Outputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
{
absl::WriterMutexLock l(&mutex_);
inference_state_ = std::unique_ptr<InferenceState>();
}
if (options_.batch_size() == 1 || options_.batched_input()) {
cc->SetOffset(0);
}
return absl::OkStatus();
}
// Adds a batch dimension to the input tensor if specified in the calculator
// options.
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(input_tensor->shape());
new_shape.InsertDim(0, 1);
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
<< "Could not add 0th dimension to tensor without changing its shape."
<< " Current shape: " << input_tensor->shape().DebugString();
}
return absl::OkStatus();
}
absl::Status AggregateTensorPacket(
const std::string& tag_name, const Packet& packet,
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
input_tensors_by_tag_by_timestamp,
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
RET_CHECK_OK(AddBatchDimension(&input_tensor));
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
// If we receive an input on a recurrent tag, override the state.
// It's OK to override the global state because there is just one
// input stream allowed for recurrent tensors.
inference_state_->input_tensor_batches_[tag_name].clear();
}
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
std::make_pair(tag_name, input_tensor));
return absl::OkStatus();
}
// Removes the batch dimension of the output tensor if specified in the
// calculator options.
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(output_tensor->shape());
new_shape.RemoveDim(0);
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
<< "Could not remove 0th dimension from tensor without changing its "
<< "shape. Current shape: " << output_tensor->shape().DebugString()
<< " (The expected first dimension is 1 for a batch element.)";
}
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
std::unique_ptr<InferenceState> inference_state_to_process;
{
absl::WriterMutexLock l(&mutex_);
if (inference_state_ == nullptr) {
inference_state_ = CreateInferenceState(cc);
} }
std::map<Timestamp, std::map<std::string, tf::Tensor>> // A mediapipe::TensorFlowSession with a model loaded and ready for use.
input_tensors_by_tag_by_timestamp; // For this calculator it must include a tag_to_tensor_map.
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
// Recurrent tensors can be empty. cc->InputSidePackets()
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { .Tag("RECURRENT_INIT_TENSORS")
if (options_.skip_on_missing_features()) { .Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
return absl::OkStatus(); }
} else { return absl::OkStatus();
return absl::InvalidArgumentError(absl::StrCat( }
"Tag ", tag_as_node_name,
" not present at timestamp: ", cc->InputTimestamp().Value())); std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
std::unique_ptr<InferenceState> inference_state =
absl::make_unique<InferenceState>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
for (const auto& p : *init_tensor_map) {
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
}
}
return inference_state;
}
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
session_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.session.get();
tag_to_tensor_map_ = cc->InputSidePackets()
.Tag("SESSION")
.Get<TensorFlowSession>()
.tag_to_tensor_map;
// Validate and store the recurrent tags
RET_CHECK(options_.has_batch_size());
RET_CHECK(options_.batch_size() == 1 ||
options_.recurrent_tag_pair().empty())
<< "To use recurrent_tag_pairs, batch_size must be 1.";
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
RET_CHECK_EQ(tags.size(), 2)
<< "recurrent_tag_pair must be a colon "
"separated std::string with two components: "
<< tag_pair;
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
<< "Can't find tag '" << tags[0] << "' in signature "
<< options_.signature_name();
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
<< "Can't find tag '" << tags[1] << "' in signature "
<< options_.signature_name();
recurrent_feed_tags_.insert(tags[0]);
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
}
// Check that all tags are present in this signature bound to tensors.
for (const std::string& tag : cc->Inputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
for (const std::string& tag : cc->Outputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature "
<< options_.signature_name();
}
{
absl::WriterMutexLock l(&mutex_);
inference_state_ = std::unique_ptr<InferenceState>();
}
if (options_.batch_size() == 1 || options_.batched_input()) {
cc->SetOffset(0);
}
return absl::OkStatus();
}
// Adds a batch dimension to the input tensor if specified in the calculator
// options.
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(input_tensor->shape());
new_shape.InsertDim(0, 1);
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
<< "Could not add 0th dimension to tensor without changing its shape."
<< " Current shape: " << input_tensor->shape().DebugString();
}
return absl::OkStatus();
}
absl::Status AggregateTensorPacket(
const std::string& tag_name, const Packet& packet,
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
input_tensors_by_tag_by_timestamp,
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
RET_CHECK_OK(AddBatchDimension(&input_tensor));
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
// If we receive an input on a recurrent tag, override the state.
// It's OK to override the global state because there is just one
// input stream allowed for recurrent tensors.
inference_state_->input_tensor_batches_[tag_name].clear();
}
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
std::make_pair(tag_name, input_tensor));
return absl::OkStatus();
}
// Removes the batch dimension of the output tensor if specified in the
// calculator options.
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
if (options_.add_batch_dim_to_tensors()) {
tf::TensorShape new_shape(output_tensor->shape());
new_shape.RemoveDim(0);
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
<< "Could not remove 0th dimension from tensor without changing its "
<< "shape. Current shape: " << output_tensor->shape().DebugString()
<< " (The expected first dimension is 1 for a batch element.)";
}
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
std::unique_ptr<InferenceState> inference_state_to_process;
{
absl::WriterMutexLock l(&mutex_);
if (inference_state_ == nullptr) {
inference_state_ = CreateInferenceState(cc);
}
std::map<Timestamp, std::map<std::string, tf::Tensor>>
input_tensors_by_tag_by_timestamp;
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
// Recurrent tensors can be empty.
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
if (options_.skip_on_missing_features()) {
return absl::OkStatus();
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Tag ", tag_as_node_name,
" not present at timestamp: ", cc->InputTimestamp().Value()));
}
} }
} else if (options_.batched_input()) {
const auto& tensor_packets =
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
if (tensor_packets.size() > options_.batch_size()) {
return absl::InvalidArgumentError(absl::StrCat(
"Batch for tag ", tag_as_node_name,
" has more packets than batch capacity. batch_size: ",
options_.batch_size(), " packets: ", tensor_packets.size()));
}
for (const auto& packet : tensor_packets) {
RET_CHECK_OK(AggregateTensorPacket(
tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp,
inference_state_.get()));
}
} else {
RET_CHECK_OK(AggregateTensorPacket(
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
&input_tensors_by_tag_by_timestamp, inference_state_.get()));
} }
} else if (options_.batched_input()) { }
const auto& tensor_packets = for (const auto& timestamp_and_input_tensors_by_tag :
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>(); input_tensors_by_tag_by_timestamp) {
if (tensor_packets.size() > options_.batch_size()) { inference_state_->batch_timestamps_.emplace_back(
return absl::InvalidArgumentError(absl::StrCat( timestamp_and_input_tensors_by_tag.first);
"Batch for tag ", tag_as_node_name, for (const auto& input_tensor_and_tag :
" has more packets than batch capacity. batch_size: ", timestamp_and_input_tensors_by_tag.second) {
options_.batch_size(), " packets: ", tensor_packets.size())); inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
.emplace_back(input_tensor_and_tag.second);
} }
for (const auto& packet : tensor_packets) { }
RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet, if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
&input_tensors_by_tag_by_timestamp, options_.batched_input()) {
inference_state_.get())); inference_state_to_process = std::move(inference_state_);
inference_state_ = std::unique_ptr<InferenceState>();
}
}
if (inference_state_to_process) {
MP_RETURN_IF_ERROR(
OutputBatch(cc, std::move(inference_state_to_process)));
}
return absl::OkStatus();
}
absl::Status Close(CalculatorContext* cc) override {
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
{
absl::WriterMutexLock l(&mutex_);
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
!inference_state_->batch_timestamps_.empty()) {
inference_state_to_process = std::move(inference_state_);
inference_state_ = std::unique_ptr<InferenceState>();
}
}
if (inference_state_to_process) {
MP_RETURN_IF_ERROR(
OutputBatch(cc, std::move(inference_state_to_process)));
}
return absl::OkStatus();
}
// When a batch of input tensors is ready to be run, runs TensorFlow and
// outputs the output tensors. The output tensors have timestamps matching
// the input tensor that formed that batch element. Any requested
// batch_dimension is added and removed. This code takes advantage of the fact
// that copying a tensor shares the same reference-counted, heap allocated
// memory buffer. Therefore, copies are cheap and should not cause the memory
// buffer to fall out of scope. In contrast, concat is only used where
// necessary.
absl::Status OutputBatch(CalculatorContext* cc,
std::unique_ptr<InferenceState> inference_state) {
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
if (options_.batch_size() == 1) {
// Short circuit to avoid the cost of deep copying tensors in concat.
if (!keyed_tensors.second.empty()) {
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
keyed_tensors.second[0]);
} else {
// The input buffer can be empty for recurrent tensors.
RET_CHECK(
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
<< "A non-recurrent tensor does not have an input: "
<< keyed_tensors.first;
} }
} else { } else {
RET_CHECK_OK(AggregateTensorPacket( // Pad by replicating the first tens or, then ignore the values.
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), keyed_tensors.second.resize(options_.batch_size());
&input_tensors_by_tag_by_timestamp, inference_state_.get())); std::fill(keyed_tensors.second.begin() +
} inference_state->batch_timestamps_.size(),
} keyed_tensors.second.end(), keyed_tensors.second[0]);
for (const auto& timestamp_and_input_tensors_by_tag : tf::Tensor concated;
input_tensors_by_tag_by_timestamp) { const tf::Status concat_status =
inference_state_->batch_timestamps_.emplace_back( tf::tensor::Concat(keyed_tensors.second, &concated);
timestamp_and_input_tensors_by_tag.first); CHECK(concat_status.ok()) << concat_status.ToString();
for (const auto& input_tensor_and_tag :
timestamp_and_input_tensors_by_tag.second) {
inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
.emplace_back(input_tensor_and_tag.second);
}
}
if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
options_.batched_input()) {
inference_state_to_process = std::move(inference_state_);
inference_state_ = std::unique_ptr<InferenceState>();
}
}
if (inference_state_to_process) {
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
}
return absl::OkStatus();
}
absl::Status Close(CalculatorContext* cc) override {
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
{
absl::WriterMutexLock l(&mutex_);
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
!inference_state_->batch_timestamps_.empty()) {
inference_state_to_process = std::move(inference_state_);
inference_state_ = std::unique_ptr<InferenceState>();
}
}
if (inference_state_to_process) {
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
}
return absl::OkStatus();
}
// When a batch of input tensors is ready to be run, runs TensorFlow and
// outputs the output tensors. The output tensors have timestamps matching
// the input tensor that formed that batch element. Any requested
// batch_dimension is added and removed. This code takes advantage of the fact
// that copying a tensor shares the same reference-counted, heap allocated
// memory buffer. Therefore, copies are cheap and should not cause the memory
// buffer to fall out of scope. In contrast, concat is only used where
// necessary.
absl::Status OutputBatch(CalculatorContext* cc,
std::unique_ptr<InferenceState> inference_state) {
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
if (options_.batch_size() == 1) {
// Short circuit to avoid the cost of deep copying tensors in concat.
if (!keyed_tensors.second.empty()) {
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
keyed_tensors.second[0]); concated);
} else {
// The input buffer can be empty for recurrent tensors.
RET_CHECK(
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
<< "A non-recurrent tensor does not have an input: "
<< keyed_tensors.first;
} }
} else {
// Pad by replicating the first tens or, then ignore the values.
keyed_tensors.second.resize(options_.batch_size());
std::fill(keyed_tensors.second.begin() +
inference_state->batch_timestamps_.size(),
keyed_tensors.second.end(), keyed_tensors.second[0]);
tf::Tensor concated;
const tf::Status concat_status =
tf::tensor::Concat(keyed_tensors.second, &concated);
CHECK(concat_status.ok()) << concat_status.ToString();
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
concated);
} }
} inference_state->input_tensor_batches_.clear();
inference_state->input_tensor_batches_.clear(); std::vector<mediapipe::ProtoString> output_tensor_names;
std::vector<mediapipe::ProtoString> output_tensor_names; std::vector<std::string> output_name_in_signature;
std::vector<std::string> output_name_in_signature; for (const std::string& tag : cc->Outputs().GetTags()) {
for (const std::string& tag : cc->Outputs().GetTags()) { output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); output_name_in_signature.emplace_back(tag);
output_name_in_signature.emplace_back(tag);
}
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
// Ensure that we always fetch the recurrent state tensors.
if (std::find(output_name_in_signature.begin(),
output_name_in_signature.end(),
tag_pair.first) == output_name_in_signature.end()) {
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
output_name_in_signature.emplace_back(tag_pair.first);
} }
} for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
std::vector<tf::Tensor> outputs; // Ensure that we always fetch the recurrent state tensors.
if (std::find(output_name_in_signature.begin(),
output_name_in_signature.end(),
tag_pair.first) == output_name_in_signature.end()) {
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
output_name_in_signature.emplace_back(tag_pair.first);
}
}
std::vector<tf::Tensor> outputs;
SimpleSemaphore* session_run_throttle = nullptr; SimpleSemaphore* session_run_throttle = nullptr;
if (options_.max_concurrent_session_runs() > 0) { if (options_.max_concurrent_session_runs() > 0) {
session_run_throttle = session_run_throttle =
get_session_run_throttle(options_.max_concurrent_session_runs()); get_session_run_throttle(options_.max_concurrent_session_runs());
session_run_throttle->Acquire(1); session_run_throttle->Acquire(1);
} }
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
tf::Status tf_status; tf::Status tf_status;
{ {
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) #if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
#endif #endif
tf_status = session_->Run(input_tensors, output_tensor_names, tf_status = session_->Run(input_tensors, output_tensor_names,
{} /* target_node_names */, &outputs); {} /* target_node_names */, &outputs);
} }
if (session_run_throttle != nullptr) { if (session_run_throttle != nullptr) {
session_run_throttle->Release(1); session_run_throttle->Release(1);
} }
// RET_CHECK on the tf::Status object itself in order to print an // RET_CHECK on the tf::Status object itself in order to print an
// informative error message. // informative error message.
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
->IncrementBy(run_end_time - run_start_time); ->IncrementBy(run_end_time - run_start_time);
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
// Feed back the recurrent state. // Feed back the recurrent state.
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
int pos = std::find(output_name_in_signature.begin(), int pos = std::find(output_name_in_signature.begin(),
output_name_in_signature.end(), tag_pair.first) - output_name_in_signature.end(), tag_pair.first) -
output_name_in_signature.begin(); output_name_in_signature.begin();
inference_state->input_tensor_batches_[tag_pair.second].emplace_back( inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
outputs[pos]); outputs[pos]);
} }
absl::WriterMutexLock l(&mutex_); absl::WriterMutexLock l(&mutex_);
    // Set that we want to split on each index of the 0th dimension.
    std::vector<tf::int64> split_vector(options_.batch_size(), 1);
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      if (options_.batch_size() == 1) {
        if (cc->Outputs().HasTag(output_name_in_signature[i])) {
          tf::Tensor output_tensor(outputs[i]);
          RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
          cc->Outputs()
              .Tag(output_name_in_signature[i])
              .Add(new tf::Tensor(output_tensor),
                   inference_state->batch_timestamps_[0]);
        }
      } else {
        std::vector<tf::Tensor> split_tensors;
        const tf::Status split_status =
            tf::tensor::Split(outputs[i], split_vector, &split_tensors);
        CHECK(split_status.ok()) << split_status.ToString();
        // Loop over timestamps so that we don't copy the padding.
        for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
          tf::Tensor output_tensor(split_tensors[j]);
          RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
          cc->Outputs()
              .Tag(output_name_in_signature[i])
              .Add(new tf::Tensor(output_tensor),
                   inference_state->batch_timestamps_[j]);
        }
      }
    }

    // Get end time and report.
    const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
    cc->GetCounter(kTotalUsecsCounterSuffix)
        ->IncrementBy(end_time - start_time);
    cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
        ->IncrementBy(inference_state->batch_timestamps_.size());

    // Make sure we hold on to the recursive state.
    if (!options_.recurrent_tag_pair().empty()) {
      inference_state_ = std::move(inference_state);
      inference_state_->batch_timestamps_.clear();
    }

    return absl::OkStatus();
  }

 private:
  // The Session object is provided by a packet factory and is owned by the
  // MediaPipe framework. Individual calls are thread-safe, but session state
  // may be shared across threads.
  tf::Session* session_;

  // A mapping between stream tags and the tensor names they are bound to.
  std::map<std::string, std::string> tag_to_tensor_map_;

  absl::Mutex mutex_;
  std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);

  // The options for the calculator.
  TensorFlowInferenceCalculatorOptions options_;

  // Store the feed and fetch tags for feed/fetch recurrent networks.
  std::set<std::string> recurrent_feed_tags_;
  std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;

  // Clock used to measure the computation time in OutputBatch().
  std::unique_ptr<mediapipe::Clock> clock_;

  // The static singleton semaphore to throttle concurrent session runs.
  static SimpleSemaphore* get_session_run_throttle(
      int32 max_concurrent_session_runs) {
    static SimpleSemaphore* session_run_throttle =
        new SimpleSemaphore(max_concurrent_session_runs);
    return session_run_throttle;
  }
};
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);

constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];

View File

@ -80,6 +80,7 @@ const std::string MaybeConvertSignatureToTag(
// which in turn contains a TensorFlow Session ready for execution and a map
// between tags and tensor names.
//
// Example usage:
// node {
//   calculator: "TensorFlowSessionFromSavedModelCalculator"

View File

@ -217,38 +217,41 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
          first_timestamp_seen_ = recent_timestamp;
        }
      }
      if (recent_timestamp > last_timestamp_seen &&
          recent_timestamp < Timestamp::PostStream().Value()) {
        last_timestamp_key_ = map_kv.first;
        last_timestamp_seen = recent_timestamp;
      }
    }
  }
  if (!timestamps_.empty()) {
    for (const auto& kv : timestamps_) {
      if (!kv.second.empty() &&
          kv.second[0] < Timestamp::PostStream().Value()) {
        // These checks only make sense if any values are not PostStream, but
        // only need to be made once.
        RET_CHECK(!last_timestamp_key_.empty())
            << "Something went wrong because the timestamp key is unset. "
            << "Example: " << sequence_->DebugString();
        RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
            << "Something went wrong because the last timestamp is unset. "
            << "Example: " << sequence_->DebugString();
        RET_CHECK_LT(first_timestamp_seen_,
                     Timestamp::OneOverPostStream().Value())
            << "Something went wrong because the first timestamp is unset. "
            << "Example: " << sequence_->DebugString();
        break;
      }
    }
  }
  current_timestamp_index_ = 0;
  process_poststream_ = false;

  // Determine the data path and output it.
  const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
  const auto& sequence = cc->InputSidePackets()
                             .Tag(kSequenceExampleTag)
                             .Get<tensorflow::SequenceExample>();
  if (cc->Outputs().HasTag(kKeypointsTag)) {
    keypoint_names_ = absl::StrSplit(options.keypoint_names(), ',');
    default_keypoint_location_ = options.default_keypoint_location();
  }
  if (cc->OutputSidePackets().HasTag(kDataPath)) {
    std::string root_directory = "";
    if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
@ -349,19 +352,30 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
  // all packets on all streams that have a timestamp between the current
  // reference timestep and the previous reference timestep. This ensures that
  // we emit all timestamps in order, but also only emit a limited number in
  // any particular call to Process(). At the very end, we output the
  // poststream packets. If we only have poststream packets,
  // last_timestamp_key_ will be empty.
  int64 start_timestamp = 0;
  int64 end_timestamp = 0;
  if (last_timestamp_key_.empty() || process_poststream_) {
    process_poststream_ = true;
    start_timestamp = Timestamp::PostStream().Value();
    end_timestamp = Timestamp::OneOverPostStream().Value();
  } else {
    start_timestamp =
        timestamps_[last_timestamp_key_][current_timestamp_index_];
    if (current_timestamp_index_ == 0) {
      start_timestamp = first_timestamp_seen_;
    }
    end_timestamp = start_timestamp + 1;  // Base case at end of sequence.
    if (current_timestamp_index_ <
        timestamps_[last_timestamp_key_].size() - 1) {
      end_timestamp =
          timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
    }
  }
  for (const auto& map_kv : timestamps_) {
    for (int i = 0; i < map_kv.second.size(); ++i) {
      if (map_kv.second[i] >= start_timestamp &&
@ -438,7 +452,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
    if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) {
      return absl::OkStatus();
    } else {
      if (process_poststream_) {
        // Once we've processed the PostStream timestamp we can stop.
        return tool::StatusStop();
      } else {
        // Otherwise, we still need to do one more pass to process it.
        process_poststream_ = true;
        return absl::OkStatus();
      }
    }
  }
@ -462,6 +483,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
  std::vector<std::string> keypoint_names_;
  // Default keypoint location when missing.
  float default_keypoint_location_;
  bool process_poststream_;
};
REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
}  // namespace mediapipe

View File

@ -412,6 +412,72 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
              ::testing::Eq(Timestamp::PostStream()));
}
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
SetUpCalculator({"IMAGE:images"}, {});
auto input_sequence = absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
std::string test_image_string = "test_image_string";
int num_images = 1;
for (int i = 0; i < num_images; ++i) {
mpms::AddImageTimestamp(i, input_sequence.get());
mpms::AddImageEncoded(test_image_string, input_sequence.get());
}
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
const std::string& output_image = output_packets[i].Get<std::string>();
ASSERT_EQ(output_image, test_image_string);
}
}
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
SetUpCalculator({"FLOAT_FEATURE_FDENSE_MAX:max"}, {});
auto input_sequence = absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
std::string test_image_string = "test_image_string";
int num_images = 1;
for (int i = 0; i < num_images; ++i) {
mpms::AddImageTimestamp(i, input_sequence.get());
mpms::AddImageEncoded(test_image_string, input_sequence.get());
}
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& fdense_max_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
ASSERT_EQ(fdense_max_packets.size(), 1);
const auto& fdense_max_vector =
fdense_max_packets[0].Get<std::vector<float>>();
ASSERT_THAT(fdense_max_vector, ::testing::ElementsAreArray({3.0f, 4.0f}));
ASSERT_THAT(fdense_max_packets[0].Timestamp(),
::testing::Eq(Timestamp::PostStream()));
}
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
  SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});

View File

@ -904,7 +904,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
#if MEDIAPIPE_TFLITE_GL_INFERENCE
    // Configure and create the delegate.
    TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
    options.compile_options.precision_loss_allowed =
        allow_precision_loss_ ? 1 : 0;
    options.compile_options.preferred_gl_object_type =
        TFLITE_GL_OBJECT_TYPE_FASTEST;
    options.compile_options.dynamic_batch_enabled = 0;
@ -968,7 +969,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
    const int kHalfSize = 2;  // sizeof(half)
    // Configure and create the delegate.
    TFLGpuDelegateOptions options;
    options.allow_precision_loss = allow_precision_loss_;
    options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive;
    if (!delegate_)
      delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options),
@ -1080,9 +1081,10 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
    }

    // Create converter for GPU output.
    converter_from_BPHWC4_ =
        [[TFLBufferConvert alloc] initWithDevice:device
                                       isFloat16:allow_precision_loss_
                                 convertToPBHWC4:false];
    if (converter_from_BPHWC4_ == nil) {
      return absl::InternalError(
          "Error initializating output buffer converter");

View File

@ -439,7 +439,7 @@ absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu(
  // Run shader, upsample result.
  {
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, small_mask_texture.id());
    GlRender();

View File

@ -821,6 +821,25 @@ cc_library(
    alwayslink = 1,
)
cc_test(
name = "landmark_projection_calculator_test",
srcs = ["landmark_projection_calculator_test.cc"],
deps = [
":landmark_projection_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_utils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest_main",
],
)
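
# The new target can be run as an ordinary Bazel test (the exact flags depend
# on the local toolchain; this is just an illustrative invocation):
#   bazel test //mediapipe/calculators/util:landmark_projection_calculator_test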
mediapipe_proto_library(
    name = "landmarks_smoothing_calculator_proto",
    srcs = ["landmarks_smoothing_calculator.proto"],
@ -1252,3 +1271,45 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
], ],
) )
mediapipe_proto_library(
name = "refine_landmarks_from_heatmap_calculator_proto",
srcs = ["refine_landmarks_from_heatmap_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "refine_landmarks_from_heatmap_calculator",
srcs = ["refine_landmarks_from_heatmap_calculator.cc"],
hdrs = ["refine_landmarks_from_heatmap_calculator.h"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":refine_landmarks_from_heatmap_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:statusor",
],
alwayslink = 1,
)
cc_test(
name = "refine_landmarks_from_heatmap_calculator_test",
srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"],
deps = [
":refine_landmarks_from_heatmap_calculator",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -402,7 +402,7 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc,
  // Blend overlay image in GPU shader.
  {
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, input_texture.name());

View File

@ -54,6 +54,7 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
 private:
  absl::node_hash_map<int, std::string> label_map_;
  ::mediapipe::DetectionLabelIdToTextCalculatorOptions options_;
};
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
@ -68,13 +69,13 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  options_ =
      cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
  if (options_.has_label_map_path()) {
    std::string string_path;
    ASSIGN_OR_RETURN(string_path,
                     PathToResourceAsFile(options_.label_map_path()));
    std::string label_map_string;
    MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
@ -85,8 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
      label_map_[i++] = line;
    }
  } else {
    for (int i = 0; i < options_.label_size(); ++i) {
      label_map_[i] = options_.label(i);
    }
  }
  return absl::OkStatus();
@ -106,7 +107,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
    }
  }
  // Remove label_id field if text labels exist.
  if (has_text_label && !options_.keep_label_id()) {
    output_detection.clear_label_id();
  }
}

View File

@ -31,4 +31,9 @@ message DetectionLabelIdToTextCalculatorOptions {
  //   label: "label for id 1"
  //   ...
  repeated string label = 2;

  // By default, the `label_id` field from the input is stripped if a text label
  // could be found. By setting this field to true, it is always copied to the
  // output detections.
  optional bool keep_label_id = 3;
}
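
// For illustration only, a node using the new field could look like the
// following sketch; the stream names and label map path are placeholders
// modeled on the face detection graph elsewhere in this commit, not part of
// this change:
//
// node {
//   calculator: "DetectionLabelIdToTextCalculator"
//   input_stream: "filtered_detections"
//   output_stream: "labeled_detections"
//   options: {
//     [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
//       label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
//       keep_label_id: true
//     }
//   }
// }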

View File

@ -120,7 +120,11 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
    labels.resize(classifications.classification_size());
    scores.resize(classifications.classification_size());
    for (int i = 0; i < classifications.classification_size(); ++i) {
      if (options_.use_display_name()) {
        labels[i] = classifications.classification(i).display_name();
      } else {
        labels[i] = classifications.classification(i).label();
      }
      scores[i] = classifications.classification(i).score();
    }
  } else {
} else { } else {

View File

@ -59,4 +59,7 @@ message LabelsToRenderDataCalculatorOptions {
    BOTTOM_LEFT = 1;
  }
  optional Location location = 6 [default = TOP_LEFT];

  // Uses Classification.display_name field instead of Classification.label.
  optional bool use_display_name = 9 [default = false];
}
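
// A minimal sketch of how the new option is set; the stream tags and names
// below are assumptions for illustration and do not come from this change:
//
// node {
//   calculator: "LabelsToRenderDataCalculator"
//   input_stream: "CLASSIFICATIONS:classifications"
//   output_stream: "RENDER_DATA:labels_render_data"
//   options: {
//     [mediapipe.LabelsToRenderDataCalculatorOptions.ext] {
//       use_display_name: true
//     }
//   }
// }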

View File

@ -13,6 +13,7 @@
// limitations under the License.

#include <cmath>
#include <functional>
#include <vector>

#include "mediapipe/calculators/util/landmark_projection_calculator.pb.h"
@ -27,20 +28,32 @@ namespace {
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kRectTag[] = "NORM_RECT";
constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX";

}  // namespace
// Projects normalized landmarks to their original coordinates.
// Input:
//   NORM_LANDMARKS - NormalizedLandmarkList
//     Represents landmarks in a normalized rectangle if NORM_RECT is specified
//     or landmarks that should be projected using PROJECTION_MATRIX if
//     specified. (Prefer using PROJECTION_MATRIX as it eliminates the need for
//     a letterbox removal step.)
//   NORM_RECT - NormalizedRect
//     Represents a normalized rectangle in image coordinates and results in
//     landmarks with their locations adjusted to the image.
//   PROJECTION_MATRIX - std::array<float, 16>
//     A 4x4 row-major-order matrix that maps landmarks' locations from one
//     coordinate system to another. In this case from the coordinate system of
//     the normalized region of interest to the coordinate system of the image.
//
// Note: either NORM_RECT or PROJECTION_MATRIX has to be specified.
// Note: a landmark's Z is projected in a custom way - it's scaled by the width
// of the normalized region of interest used during landmarks detection.
//
// Output:
//   NORM_LANDMARKS - NormalizedLandmarkList
//     Landmarks with their locations adjusted according to the inputs.
//
// Usage example:
// node {
@ -58,12 +71,27 @@ constexpr char kRectTag[] = "NORM_RECT";
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0" // output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1" // output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
// } // }
//
// node {
// calculator: "LandmarkProjectionCalculator"
// input_stream: "NORM_LANDMARKS:landmarks"
// input_stream: "PROECTION_MATRIX:matrix"
// output_stream: "NORM_LANDMARKS:projected_landmarks"
// }
//
// node {
// calculator: "LandmarkProjectionCalculator"
// input_stream: "NORM_LANDMARKS:0:landmarks_0"
// input_stream: "NORM_LANDMARKS:1:landmarks_1"
// input_stream: "PROECTION_MATRIX:matrix"
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
// }
class LandmarkProjectionCalculator : public CalculatorBase { class LandmarkProjectionCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && RET_CHECK(cc->Inputs().HasTag(kLandmarksTag))
cc->Inputs().HasTag(kRectTag)) << "Missing NORM_LANDMARKS input.";
<< "Missing one or more input streams.";
RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag), RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag),
cc->Outputs().NumEntries(kLandmarksTag)) cc->Outputs().NumEntries(kLandmarksTag))
@ -73,7 +101,14 @@ class LandmarkProjectionCalculator : public CalculatorBase {
         id != cc->Inputs().EndId(kLandmarksTag); ++id) {
      cc->Inputs().Get(id).Set<NormalizedLandmarkList>();
    }

    RET_CHECK(cc->Inputs().HasTag(kRectTag) ^
              cc->Inputs().HasTag(kProjectionMatrix))
        << "Either NORM_RECT or PROJECTION_MATRIX must be specified.";
    if (cc->Inputs().HasTag(kRectTag)) {
      cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
    } else {
      cc->Inputs().Tag(kProjectionMatrix).Set<std::array<float, 16>>();
    }

    for (CollectionItemId id = cc->Outputs().BeginId(kLandmarksTag);
         id != cc->Outputs().EndId(kLandmarksTag); ++id) {
@ -89,31 +124,50 @@ class LandmarkProjectionCalculator : public CalculatorBase {
    return absl::OkStatus();
  }

  static void ProjectXY(const NormalizedLandmark& lm,
                        const std::array<float, 16>& matrix,
                        NormalizedLandmark* out) {
    out->set_x(lm.x() * matrix[0] + lm.y() * matrix[1] + lm.z() * matrix[2] +
               matrix[3]);
    out->set_y(lm.x() * matrix[4] + lm.y() * matrix[5] + lm.z() * matrix[6] +
               matrix[7]);
  }

  /**
   * Landmark's Z scale is equal to a relative (to image) width of region of
   * interest used during detection. To calculate based on matrix:
   * 1. Project (0,0) --- (1,0) segment using matrix.
   * 2. Calculate length of the projected segment.
   */
  static float CalculateZScale(const std::array<float, 16>& matrix) {
    NormalizedLandmark a;
    a.set_x(0.0f);
    a.set_y(0.0f);
    NormalizedLandmark b;
    b.set_x(1.0f);
    b.set_y(0.0f);
    NormalizedLandmark a_projected;
    ProjectXY(a, matrix, &a_projected);
    NormalizedLandmark b_projected;
    ProjectXY(b, matrix, &b_projected);
    return std::sqrt(std::pow(b_projected.x() - a_projected.x(), 2) +
                     std::pow(b_projected.y() - a_projected.y(), 2));
  }

  absl::Status Process(CalculatorContext* cc) override {
    std::function<void(const NormalizedLandmark&, NormalizedLandmark*)>
        project_fn;
    if (cc->Inputs().HasTag(kRectTag)) {
      if (cc->Inputs().Tag(kRectTag).IsEmpty()) {
        return absl::OkStatus();
      }
      const auto& input_rect =
          cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
      const auto& options =
          cc->Options<mediapipe::LandmarkProjectionCalculatorOptions>();
      project_fn = [&input_rect, &options](const NormalizedLandmark& landmark,
                                           NormalizedLandmark* new_landmark) {
        // TODO: fix projection or deprecate (current projection
        // calculations are incorrect for general case).
        const float x = landmark.x() - 0.5f;
        const float y = landmark.y() - 0.5f;
        const float angle =
@ -130,10 +184,44 @@ class LandmarkProjectionCalculator : public CalculatorBase {
        new_landmark->set_x(new_x);
        new_landmark->set_y(new_y);
        new_landmark->set_z(new_z);
      };
    } else if (cc->Inputs().HasTag(kProjectionMatrix)) {
      if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) {
        return absl::OkStatus();
      }
      const auto& project_mat =
          cc->Inputs().Tag(kProjectionMatrix).Get<std::array<float, 16>>();
      const float z_scale = CalculateZScale(project_mat);
      project_fn = [&project_mat, z_scale](const NormalizedLandmark& lm,
                                           NormalizedLandmark* new_landmark) {
        *new_landmark = lm;
        ProjectXY(lm, project_mat, new_landmark);
        new_landmark->set_z(z_scale * lm.z());
      };
    } else {
      return absl::InternalError("Either rect or matrix must be specified.");
    }

    CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag);
    CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag);
    // Number of inputs and outputs is the same according to the contract.
    for (; input_id != cc->Inputs().EndId(kLandmarksTag);
         ++input_id, ++output_id) {
      const auto& input_packet = cc->Inputs().Get(input_id);
      if (input_packet.IsEmpty()) {
        continue;
      }

      const auto& input_landmarks = input_packet.Get<NormalizedLandmarkList>();
      NormalizedLandmarkList output_landmarks;
      for (int i = 0; i < input_landmarks.landmark_size(); ++i) {
        const NormalizedLandmark& landmark = input_landmarks.landmark(i);
        NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
        project_fn(landmark, new_landmark);
      }

      cc->Outputs().Get(output_id).AddPacket(
          MakePacket<NormalizedLandmarkList>(std::move(output_landmarks))
              .At(cc->InputTimestamp()));
    }
    return absl::OkStatus();
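
A brief worked example of the matrix path (hand-derived here, but consistent with the cropped-rect tests in the test file that follows): a region of interest centered at (0.5, 0.5) with normalized width 0.5, height 2 and no rotation corresponds to the row-major matrix

  0.5  0   0   0.25
  0    2   0  -0.5
  0    0   1   0
  0    0   0   1

ProjectXY maps the landmark (x=1.0, y=1.0) to x' = 1.0*0.5 + 0.25 = 0.75 and y' = 1.0*2 - 0.5 = 1.5, while CalculateZScale projects the (0,0)-(1,0) segment to length 0.5, so z' = 0.5 * (-0.5) = -0.25, which matches the expected result {x: 0.75, y: 1.5, z: -0.25}.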

View File

@ -0,0 +1,240 @@
#include <array>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
mediapipe::CalculatorRunner runner(
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:landmarks"
input_stream: "NORM_RECT:rect"
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag("NORM_LANDMARKS")
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag("NORM_RECT")
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithDefaultRect) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
mediapipe::NormalizedRect rect =
ParseTextProtoOrDie<mediapipe::NormalizedRect>(
R"pb(
x_center: 0.5,
y_center: 0.5,
width: 1.0,
height: 1.0,
rotation: 0.0
)pb");
auto status_or_result = RunCalculator(std::move(landmarks), std::move(rect));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb")));
}
mediapipe::NormalizedRect GetCroppedRect() {
return ParseTextProtoOrDie<mediapipe::NormalizedRect>(
R"pb(
x_center: 0.5, y_center: 0.5, width: 0.5, height: 2, rotation: 0.0
)pb");
}
mediapipe::NormalizedLandmarkList GetCroppedRectTestInput() {
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 1.0, y: 1.0, z: -0.5 }
)pb");
}
mediapipe::NormalizedLandmarkList GetCroppedRectTestExpectedResult() {
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 0.75, y: 1.5, z: -0.25 }
)pb");
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRect) {
auto status_or_result =
RunCalculator(GetCroppedRectTestInput(), GetCroppedRect());
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(status_or_result.value(),
EqualsProto(GetCroppedRectTestExpectedResult()));
}
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, std::array<float, 16> matrix) {
mediapipe::CalculatorRunner runner(
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:landmarks"
input_stream: "PROJECTION_MATRIX:matrix"
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag("NORM_LANDMARKS")
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag("PROJECTION_MATRIX")
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithIdentityMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
1.0f, 0.0f, 0.0f, 0.0f,
0.0f, 1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRectMatrix) {
constexpr int kRectWidth = 1280;
constexpr int kRectHeight = 720;
auto roi = GetRoi(kRectWidth, kRectHeight, GetCroppedRect());
std::array<float, 16> matrix;
GetRotatedSubRectToRectTransformMatrix(roi, kRectWidth, kRectHeight,
/*flip_horizontaly=*/false, &matrix);
auto status_or_result = RunCalculator(GetCroppedRectTestInput(), matrix);
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(status_or_result.value(),
EqualsProto(GetCroppedRectTestExpectedResult()));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithScaleMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
landmark { x: 5, y: 6, z: 7 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
10.0f, 0.0f, 0.0f, 0.0f,
0.0f, 100.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 100, y: 2000, z: -5 }
landmark { x: 50, y: 600, z: 70 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithTranslateMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
1.0f, 0.0f, 0.0f, 1.0f,
0.0f, 1.0f, 0.0f, 2.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 11, y: 22, z: -0.5 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithRotationMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 4, y: 0, z: -0.5 }
)pb");
// clang-format off
// 90 degrees rotation matrix
std::array<float, 16> matrix = {
0.0f, -1.0f, 0.0f, 0.0f,
1.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 0, y: 4, z: -0.5 }
)pb")));
}
} // namespace
} // namespace mediapipe

View File

@ -1,3 +1,17 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h" #include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"

View File

@ -1,3 +1,17 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe; package mediapipe;

View File

@ -0,0 +1,166 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
namespace {
inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); }
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
const std::vector<int>& dims) {
if (dims.size() == 3) {
return std::make_tuple(dims[0], dims[1], dims[2]);
} else if (dims.size() == 4) {
// BHWC format check B == 1
RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
return std::make_tuple(dims[1], dims[2], dims[3]);
} else {
RET_CHECK(false) << "Invalid shape size for heatmap tensor" << dims.size();
}
}
} // namespace
namespace api2 {
// Refines landmarks using the corresponding heatmap area.
//
// Input:
// NORM_LANDMARKS - Required. Input normalized landmarks to update.
// TENSORS - Required. Vector of input tensors. 0th element should be heatmap.
// The rest is unused.
// Output:
// NORM_LANDMARKS - Required. Updated normalized landmarks.
class RefineLandmarksFromHeatmapCalculatorImpl
: public NodeImpl<RefineLandmarksFromHeatmapCalculator,
RefineLandmarksFromHeatmapCalculatorImpl> {
public:
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) override {
// Make sure we bypass landmarks if there is no detection.
if (kInLandmarks(cc).IsEmpty()) {
return absl::OkStatus();
}
// If for some reason heatmap is missing, just return original landmarks.
if (kInTensors(cc).IsEmpty()) {
kOutLandmarks(cc).Send(*kInLandmarks(cc));
return absl::OkStatus();
}
// Check basic prerequisites.
const auto& input_tensors = *kInTensors(cc);
    RET_CHECK(!input_tensors.empty()) << "Empty input tensors list. First "
                                         "element is expected to be a heatmap";
const auto& hm_tensor = input_tensors[0];
const auto& in_lms = *kInLandmarks(cc);
auto hm_view = hm_tensor.GetCpuReadView();
auto hm_raw = hm_view.buffer<float>();
const auto& options =
cc->Options<mediapipe::RefineLandmarksFromHeatmapCalculatorOptions>();
ASSIGN_OR_RETURN(auto out_lms, RefineLandmarksFromHeatMap(
in_lms, hm_raw, hm_tensor.shape().dims,
options.kernel_size(),
options.min_confidence_to_refine()));
kOutLandmarks(cc).Send(std::move(out_lms));
return absl::OkStatus();
}
};
} // namespace api2
// Runs the actual refinement.
// High level algorithm:
//
// The heatmap is accepted as a tensor in HWC layout where the i-th channel is
// a heatmap for the i-th landmark.
//
// For each landmark we replace the original value with a value calculated from
// the area in the heatmap close to the original landmark position (in
// particular, the area covered by a kernel of size options.kernel_size). To
// calculate the new coordinate from the heatmap we compute a weighted average
// inside the kernel. We update the landmark only if the heatmap is confident
// in its prediction, i.e. max(heatmap) inside the kernel is at least
// options.min_confidence_to_refine.
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
const mediapipe::NormalizedLandmarkList& in_lms,
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
int kernel_size, float min_confidence_to_refine) {
ASSIGN_OR_RETURN(auto hm_dims, GetHwcFromDims(heatmap_dims));
auto [hm_height, hm_width, hm_channels] = hm_dims;
RET_CHECK_EQ(in_lms.landmark_size(), hm_channels)
<< "Expected heatmap to have number of layers == to number of "
"landmarks";
int hm_row_size = hm_width * hm_channels;
int hm_pixel_size = hm_channels;
mediapipe::NormalizedLandmarkList out_lms = in_lms;
for (int lm_index = 0; lm_index < out_lms.landmark_size(); ++lm_index) {
int center_col = out_lms.landmark(lm_index).x() * hm_width;
int center_row = out_lms.landmark(lm_index).y() * hm_height;
    // Point is outside of the image, so keep it intact.
    if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
        center_row >= hm_height) {
continue;
}
int offset = (kernel_size - 1) / 2;
// Calculate area to iterate over. Note that we decrease the kernel on
// the edges of the heatmap. Equivalent to zero border.
int begin_col = std::max(0, center_col - offset);
int end_col = std::min(hm_width, center_col + offset + 1);
int begin_row = std::max(0, center_row - offset);
int end_row = std::min(hm_height, center_row + offset + 1);
float sum = 0;
float weighted_col = 0;
float weighted_row = 0;
float max_value = 0;
// Main loop. Go over kernel and calculate weighted sum of coordinates,
// sum of weights and max weights.
for (int row = begin_row; row < end_row; ++row) {
for (int col = begin_col; col < end_col; ++col) {
// We expect memory to be in HWC layout without padding.
int idx = hm_row_size * row + hm_pixel_size * col + lm_index;
// Right now we hardcode sigmoid activation as it will be wasteful to
// calculate sigmoid for each value of heatmap in the model itself. If
// we ever have other activations it should be trivial to expand via
// options.
float confidence = Sigmoid(heatmap_raw_data[idx]);
sum += confidence;
max_value = std::max(max_value, confidence);
weighted_col += col * confidence;
weighted_row += row * confidence;
}
}
if (max_value >= min_confidence_to_refine && sum > 0) {
out_lms.mutable_landmark(lm_index)->set_x(weighted_col / hm_width / sum);
out_lms.mutable_landmark(lm_index)->set_y(weighted_row / hm_height / sum);
}
}
return out_lms;
}
} // namespace mediapipe
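
Stated as a formula (a restatement of the loop above, not an addition to it), with sigma the sigmoid, h the raw heatmap, K the kernel around the landmark's current position, and W, H the heatmap width and height, landmark i is updated to

  x' = sum over (row, col) in K of col * sigma(h[row, col, i]) / (W * sum of sigma(h[row, col, i]))
  y' = sum over (row, col) in K of row * sigma(h[row, col, i]) / (H * sum of sigma(h[row, col, i]))

and only when max sigma(h) over K is at least min_confidence_to_refine. For instance, in the Smoke test added later in this commit, a single confident cell at row 1, col 0 of a 3x3 heatmap moves the landmark from (0.5, 0.5) to (0, 1/3).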

View File

@ -0,0 +1,50 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
#include <vector>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/statusor.h"
namespace mediapipe {
namespace api2 {
class RefineLandmarksFromHeatmapCalculator : public NodeIntf {
public:
static constexpr Input<mediapipe::NormalizedLandmarkList> kInLandmarks{
"NORM_LANDMARKS"};
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
static constexpr Output<mediapipe::NormalizedLandmarkList> kOutLandmarks{
"NORM_LANDMARKS"};
MEDIAPIPE_NODE_INTERFACE(RefineLandmarksFromHeatmapCalculator, kInLandmarks,
kInTensors, kOutLandmarks);
};
} // namespace api2
// Exposed for testing.
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
const mediapipe::NormalizedLandmarkList& in_lms,
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
int kernel_size, float min_confidence_to_refine);
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_

View File

@ -0,0 +1,27 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message RefineLandmarksFromHeatmapCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional RefineLandmarksFromHeatmapCalculatorOptions ext = 362281653;
}
optional int32 kernel_size = 1 [default = 9];
optional float min_confidence_to_refine = 2 [default = 0.5];
}
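
// For reference, a graph node using these options could look like the sketch
// below; the stream names are illustrative assumptions, while the calculator
// name, stream tags and option fields come from this change:
//
// node {
//   calculator: "RefineLandmarksFromHeatmapCalculator"
//   input_stream: "NORM_LANDMARKS:raw_landmarks"
//   input_stream: "TENSORS:output_tensors"
//   output_stream: "NORM_LANDMARKS:refined_landmarks"
//   options: {
//     [mediapipe.RefineLandmarksFromHeatmapCalculatorOptions.ext] {
//       kernel_size: 7
//       min_confidence_to_refine: 0.5
//     }
//   }
// }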

View File

@ -0,0 +1,152 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
mediapipe::NormalizedLandmarkList vec_to_lms(
const std::vector<std::pair<float, float>>& inp) {
mediapipe::NormalizedLandmarkList ret;
for (const auto& it : inp) {
auto new_lm = ret.add_landmark();
new_lm->set_x(it.first);
new_lm->set_y(it.second);
}
return ret;
}
std::vector<std::pair<float, float>> lms_to_vec(
const mediapipe::NormalizedLandmarkList& lst) {
std::vector<std::pair<float, float>> ret;
for (const auto& lm : lst.landmark()) {
ret.push_back({lm.x(), lm.y()});
}
return ret;
}
std::vector<float> CHW_to_HWC(std::vector<float> inp, int height, int width,
int depth) {
std::vector<float> ret(inp.size());
const float* inp_ptr = inp.data();
for (int c = 0; c < depth; ++c) {
for (int row = 0; row < height; ++row) {
for (int col = 0; col < width; ++col) {
int dest_idx = width * depth * row + depth * col + c;
ret[dest_idx] = *inp_ptr;
++inp_ptr;
}
}
}
return ret;
}
using testing::ElementsAre;
using testing::FloatEq;
using testing::Pair;
TEST(RefineLandmarksFromHeatmapTest, Smoke) {
float z = -10000000000000000;
// clang-format off
std::vector<float> hm = {
z, z, z,
1, z, z,
z, z, z};
// clang-format on
auto ret_or_error = RefineLandmarksFromHeatMap(vec_to_lms({{0.5, 0.5}}),
hm.data(), {3, 3, 1}, 3, 0.1);
MP_EXPECT_OK(ret_or_error);
EXPECT_THAT(lms_to_vec(*ret_or_error),
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.))));
}
TEST(RefineLandmarksFromHeatmapTest, MultiLayer) {
float z = -10000000000000000;
// clang-format off
std::vector<float> hm = CHW_to_HWC({
z, z, z,
1, z, z,
z, z, z,
z, z, z,
1, z, z,
z, z, z,
z, z, z,
1, z, z,
z, z, z}, 3, 3, 3);
// clang-format on
auto ret_or_error = RefineLandmarksFromHeatMap(
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
0.1);
MP_EXPECT_OK(ret_or_error);
EXPECT_THAT(lms_to_vec(*ret_or_error),
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
Pair(FloatEq(0), FloatEq(1 / 3.)),
Pair(FloatEq(0), FloatEq(1 / 3.))));
}
TEST(RefineLandmarksFromHeatmapTest, KeepIfNotSure) {
float z = -10000000000000000;
// clang-format off
std::vector<float> hm = CHW_to_HWC({
z, z, z,
0, z, z,
z, z, z,
z, z, z,
0, z, z,
z, z, z,
z, z, z,
0, z, z,
z, z, z}, 3, 3, 3);
// clang-format on
auto ret_or_error = RefineLandmarksFromHeatMap(
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
0.6);
MP_EXPECT_OK(ret_or_error);
EXPECT_THAT(lms_to_vec(*ret_or_error),
ElementsAre(Pair(FloatEq(0.5), FloatEq(0.5)),
Pair(FloatEq(0.5), FloatEq(0.5)),
Pair(FloatEq(0.5), FloatEq(0.5))));
}
TEST(RefineLandmarksFromHeatmapTest, Border) {
float z = -10000000000000000;
// clang-format off
std::vector<float> hm = CHW_to_HWC({
z, z, z,
0, z, 0,
z, z, z,
z, z, z,
0, z, 0,
z, z, 0}, 3, 3, 2);
// clang-format on
auto ret_or_error = RefineLandmarksFromHeatMap(
vec_to_lms({{0.0, 0.0}, {0.9, 0.9}}), hm.data(), {3, 3, 2}, 3, 0.1);
MP_EXPECT_OK(ret_or_error);
EXPECT_THAT(lms_to_vec(*ret_or_error),
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
Pair(FloatEq(2 / 3.), FloatEq(1 / 6. + 2 / 6.))));
}
} // namespace
} // namespace mediapipe

View File

@ -101,7 +101,7 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
  }

  if (cc->InputSidePackets().HasTag("THRESHOLD")) {
    threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
  }
  return absl::OkStatus();
}
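
// Callers that feed the THRESHOLD input side packet now need to provide a
// double rather than a float. A minimal C++ sketch (the graph config and the
// side packet name "threshold" are placeholders, not part of this change):
//
//   mediapipe::CalculatorGraph graph;
//   MP_RETURN_IF_ERROR(graph.Initialize(config));
//   MP_RETURN_IF_ERROR(graph.StartRun(
//       {{"threshold", mediapipe::MakePacket<double>(0.5)}}));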

View File

@ -43,8 +43,7 @@ android_binary(
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
], ],
assets_dir = "", assets_dir = "",
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",

View File

@ -37,7 +37,7 @@ android_binary(
srcs = glob(["*.java"]), srcs = glob(["*.java"]),
assets = [ assets = [
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",
], ],
assets_dir = "", assets_dir = "",

View File

@ -1,64 +0,0 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
cc_binary(
name = "libmediapipe_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
],
)
cc_library(
name = "mediapipe_jni_lib",
srcs = [":libmediapipe_jni.so"],
alwayslink = 1,
)
android_binary(
name = "upperbodyposetrackinggpu",
srcs = glob(["*.java"]),
assets = [
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite",
],
assets_dir = "",
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.apps.upperbodyposetrackinggpu",
"appName": "Upper Body Pose Tracking",
"mainActivity": ".MainActivity",
"cameraFacingFront": "False",
"binaryGraphName": "upper_body_pose_tracking_gpu.binarypb",
"inputVideoStreamName": "input_video",
"outputVideoStreamName": "output_video",
"flipFramesVertically": "True",
"converterNumBuffers": "2",
},
multidex = "native",
deps = [
":mediapipe_jni_lib",
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"@com_google_protobuf//:protobuf_javalite",
],
)

View File

@ -1,75 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.upperbodyposetrackinggpu;
import android.os.Bundle;
import android.util.Log;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.framework.PacketGetter;
import com.google.protobuf.InvalidProtocolBufferException;
/** Main activity of MediaPipe upper-body pose tracking app. */
public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
private static final String TAG = "MainActivity";
private static final String OUTPUT_LANDMARKS_STREAM_NAME = "pose_landmarks";
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
// To show verbose logging, run:
// adb shell setprop log.tag.MainActivity VERBOSE
if (Log.isLoggable(TAG, Log.VERBOSE)) {
processor.addPacketCallback(
OUTPUT_LANDMARKS_STREAM_NAME,
(packet) -> {
Log.v(TAG, "Received pose landmarks packet.");
try {
NormalizedLandmarkList poseLandmarks =
PacketGetter.getProto(packet, NormalizedLandmarkList.class);
Log.v(
TAG,
"[TS:"
+ packet.getTimestamp()
+ "] "
+ getPoseLandmarksDebugString(poseLandmarks));
} catch (InvalidProtocolBufferException exception) {
Log.e(TAG, "Failed to get proto.", exception);
}
});
}
}
private static String getPoseLandmarksDebugString(NormalizedLandmarkList poseLandmarks) {
String poseLandmarkStr = "Pose landmarks: " + poseLandmarks.getLandmarkCount() + "\n";
int landmarkIndex = 0;
for (NormalizedLandmark landmark : poseLandmarks.getLandmarkList()) {
poseLandmarkStr +=
"\tLandmark ["
+ landmarkIndex
+ "]: ("
+ landmark.getX()
+ ", "
+ landmark.getY()
+ ", "
+ landmark.getZ()
+ ")\n";
++landmarkIndex;
}
return poseLandmarkStr;
}
}

View File

@ -142,26 +142,13 @@ node {
} }
} }
# Maps detection label IDs to the corresponding label text ("Face"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the # Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the # letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the # corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation). # input image to the graph before image transformation).
node { node {
calculator: "DetectionLetterboxRemovalCalculator" calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections" input_stream: "DETECTIONS:filtered_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding" input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:output_detections" output_stream: "DETECTIONS:output_detections"
} }

View File

@ -33,6 +33,10 @@ constexpr char kDetections[] = "DETECTIONS";
constexpr char kDetectedBorders[] = "BORDERS"; constexpr char kDetectedBorders[] = "BORDERS";
constexpr char kCropRect[] = "CROP_RECT"; constexpr char kCropRect[] = "CROP_RECT";
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT"; constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
// Can be used to control whether an animated zoom should actually be
// performed (configured through option us_to_first_rect). If provided, a
// true value will allow the animated zoom to be used when the first
// detections arrive.
constexpr char kAnimateZoom[] = "ANIMATE_ZOOM";
// Field-of-view (degrees) of the camera's x-axis (width). // Field-of-view (degrees) of the camera's x-axis (width).
// TODO: Parameterize FOV based on camera specs. // TODO: Parameterize FOV based on camera specs.
constexpr float kFieldOfView = 60; constexpr float kFieldOfView = 60;
@ -76,10 +80,10 @@ class ContentZoomingCalculator : public CalculatorBase {
absl::Status InitializeState(int frame_width, int frame_height); absl::Status InitializeState(int frame_width, int frame_height);
// Adjusts state to work with an updated frame size. // Adjusts state to work with an updated frame size.
absl::Status UpdateForResolutionChange(int frame_width, int frame_height); absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
// Returns true if we are zooming to the initial rect. // Returns true if we are animating to the first rect.
bool IsZoomingToInitialRect(const Timestamp& timestamp) const; bool IsAnimatingToFirstRect(const Timestamp& timestamp) const;
// Builds the output rectangle when zooming to the initial rect. // Builds the output rectangle when animating to the first rect.
absl::StatusOr<mediapipe::Rect> GetInitialZoomingRect( absl::StatusOr<mediapipe::Rect> GetAnimationRect(
int frame_width, int frame_height, const Timestamp& timestamp) const; int frame_width, int frame_height, const Timestamp& timestamp) const;
// Converts bounds to tilt offset, pan offset and height. // Converts bounds to tilt offset, pan offset and height.
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
@ -97,7 +101,10 @@ class ContentZoomingCalculator : public CalculatorBase {
std::unique_ptr<KinematicPathSolver> path_solver_tilt_; std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
// Are parameters initialized. // Are parameters initialized.
bool initialized_; bool initialized_;
// Stores the time of the first crop rectangle. // Stores the time of the first crop rectangle. This is used to control
// animating to it. Until a first crop rectangle has been computed, it has
// the value Timestamp::Unset(). If animating is not requested, it receives
// the value Timestamp::Done() instead of the time.
Timestamp first_rect_timestamp_; Timestamp first_rect_timestamp_;
// Stores the first crop rectangle. // Stores the first crop rectangle.
mediapipe::NormalizedRect first_rect_; mediapipe::NormalizedRect first_rect_;
@ -135,6 +142,9 @@ absl::Status ContentZoomingCalculator::GetContract(
if (cc->Inputs().HasTag(kDetections)) { if (cc->Inputs().HasTag(kDetections)) {
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>(); cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
} }
if (cc->Inputs().HasTag(kAnimateZoom)) {
cc->Inputs().Tag(kAnimateZoom).Set<bool>();
}
if (cc->Outputs().HasTag(kDetectedBorders)) { if (cc->Outputs().HasTag(kDetectedBorders)) {
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>(); cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
} }
@ -419,10 +429,11 @@ absl::Status ContentZoomingCalculator::UpdateForResolutionChange(
return absl::OkStatus(); return absl::OkStatus();
} }
bool ContentZoomingCalculator::IsZoomingToInitialRect( bool ContentZoomingCalculator::IsAnimatingToFirstRect(
const Timestamp& timestamp) const { const Timestamp& timestamp) const {
if (options_.us_to_first_rect() == 0 || if (options_.us_to_first_rect() == 0 ||
first_rect_timestamp_ == Timestamp::Unset()) { first_rect_timestamp_ == Timestamp::Unset() ||
first_rect_timestamp_ == Timestamp::Done()) {
return false; return false;
} }
@ -443,10 +454,10 @@ double easeInOutQuad(double t) {
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; } double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
} // namespace } // namespace
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetInitialZoomingRect( absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetAnimationRect(
int frame_width, int frame_height, const Timestamp& timestamp) const { int frame_width, int frame_height, const Timestamp& timestamp) const {
RET_CHECK(IsZoomingToInitialRect(timestamp)) RET_CHECK(IsAnimatingToFirstRect(timestamp))
<< "Must only be called if zooming to initial rect."; << "Must only be called if animating to first rect.";
const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
const int64 delay = options_.us_to_first_rect_delay(); const int64 delay = options_.us_to_first_rect_delay();
@ -538,15 +549,20 @@ absl::Status ContentZoomingCalculator::Process(
} }
} }
bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); const bool may_start_animation = (options_.us_to_first_rect() != 0) &&
(!cc->Inputs().HasTag(kAnimateZoom) ||
cc->Inputs().Tag(kAnimateZoom).Get<bool>());
bool is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
int offset_y, height, offset_x; int offset_y, height, offset_x;
if (zooming_to_initial_rect) { if (!is_animating && options_.start_zoomed_out() && !may_start_animation &&
// If we are zooming to the first rect, ignore any new incoming detections. first_rect_timestamp_ == Timestamp::Unset()) {
height = last_measured_height_; // If we should start zoomed out and won't be doing an animation,
offset_x = last_measured_x_offset_; // initialize the path solvers using the full frame, ignoring detections.
offset_y = last_measured_y_offset_; height = max_frame_value_ * frame_height_;
} else if (only_required_found) { offset_x = (target_aspect_ * height) / 2;
offset_y = frame_height_ / 2;
} else if (!is_animating && only_required_found) {
// Convert bounds to tilt/zoom and in pixel coordinates. // Convert bounds to tilt/zoom and in pixel coordinates.
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y, MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
&offset_x, &height)); &offset_x, &height));
@ -555,9 +571,9 @@ absl::Status ContentZoomingCalculator::Process(
last_measured_height_ = height; last_measured_height_ = height;
last_measured_x_offset_ = offset_x; last_measured_x_offset_ = offset_x;
last_measured_y_offset_ = offset_y; last_measured_y_offset_ = offset_y;
} else if (cc->InputTimestamp().Microseconds() - } else if (!is_animating && cc->InputTimestamp().Microseconds() -
last_only_required_detection_ >= last_only_required_detection_ >=
options_.us_before_zoomout()) { options_.us_before_zoomout()) {
// No only_require detections found within salient regions packets // No only_require detections found within salient regions packets
// arriving since us_before_zoomout duration. // arriving since us_before_zoomout duration.
height = max_frame_value_ * frame_height_ + height = max_frame_value_ * frame_height_ +
@ -566,7 +582,8 @@ absl::Status ContentZoomingCalculator::Process(
offset_x = (target_aspect_ * height) / 2; offset_x = (target_aspect_ * height) / 2;
offset_y = frame_height_ / 2; offset_y = frame_height_ / 2;
} else { } else {
// No only detection found but using last detection due to // Either animating to the first rectangle, or
// no only detection found but using last detection due to
// duration_before_zoomout_us setting. // duration_before_zoomout_us setting.
height = last_measured_height_; height = last_measured_height_;
offset_x = last_measured_x_offset_; offset_x = last_measured_x_offset_;
@ -642,24 +659,28 @@ absl::Status ContentZoomingCalculator::Process(
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); .AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
} }
if (first_rect_timestamp_ == Timestamp::Unset() && // Record the first crop rectangle
options_.us_to_first_rect() != 0) { if (first_rect_timestamp_ == Timestamp::Unset()) {
first_rect_timestamp_ = cc->InputTimestamp();
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_)); first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
first_rect_.set_width(path_height * target_aspect_ / first_rect_.set_width(path_height * target_aspect_ /
static_cast<float>(frame_width_)); static_cast<float>(frame_width_));
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_)); first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
first_rect_.set_height(path_height / static_cast<float>(frame_height_)); first_rect_.set_height(path_height / static_cast<float>(frame_height_));
// After setting the first rectangle, check whether we should zoom to it.
zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp()); // Record the time to serve as departure point for the animation.
// If we are not allowed to start the animation, set Timestamp::Done.
first_rect_timestamp_ =
may_start_animation ? cc->InputTimestamp() : Timestamp::Done();
// After setting the first rectangle, check whether we should animate to it.
is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
} }
// Transmit downstream to glcroppingcalculator. // Transmit downstream to glcroppingcalculator.
if (cc->Outputs().HasTag(kCropRect)) { if (cc->Outputs().HasTag(kCropRect)) {
std::unique_ptr<mediapipe::Rect> gpu_rect; std::unique_ptr<mediapipe::Rect> gpu_rect;
if (zooming_to_initial_rect) { if (is_animating) {
auto rect = GetInitialZoomingRect(frame_width, frame_height, auto rect =
cc->InputTimestamp()); GetAnimationRect(frame_width, frame_height, cc->InputTimestamp());
MP_RETURN_IF_ERROR(rect.status()); MP_RETURN_IF_ERROR(rect.status());
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect); gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
} else { } else {
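As a sketch of how the new ANIMATE_ZOOM input could be driven (the helper name is hypothetical; the "animate_zoom" stream name mirrors the kConfigE test config further below and assumes the graph exposes it as a graph input stream):

#include "mediapipe/framework/calculator_framework.h"

// Hedged sketch: gate the initial animated zoom per run by feeding a bool
// packet on the stream wired to the calculator's ANIMATE_ZOOM tag.
absl::Status SendAnimateZoomDecision(mediapipe::CalculatorGraph& graph,
                                     bool allow_animation,
                                     mediapipe::Timestamp timestamp) {
  return graph.AddPacketToInputStream(
      "animate_zoom",
      mediapipe::MakePacket<bool>(allow_animation).At(timestamp));
}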

View File

@ -19,7 +19,7 @@ package mediapipe.autoflip;
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
// NextTag: 17 // NextTag: 18
message ContentZoomingCalculatorOptions { message ContentZoomingCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ContentZoomingCalculatorOptions ext = 313091992; optional ContentZoomingCalculatorOptions ext = 313091992;
@ -58,9 +58,15 @@ message ContentZoomingCalculatorOptions {
// Whether to keep state between frames or to compute the final crop rect. // Whether to keep state between frames or to compute the final crop rect.
optional bool is_stateless = 14 [default = false]; optional bool is_stateless = 14 [default = false];
// Duration (in MicroSeconds) for moving to the first crop rect. // If true, on the first packet start with the camera zoomed out and then zoom
// in on the subject. If false, the camera will start zoomed in on the
// subject.
optional bool start_zoomed_out = 17 [default = false];
// Duration (in MicroSeconds) for animating to the first crop rect.
// Note that if set, takes precedence over start_zoomed_out.
optional int64 us_to_first_rect = 15 [default = 0]; optional int64 us_to_first_rect = 15 [default = 0];
// Duration (in MicroSeconds) to delay moving to the first crop rect. // Duration (in MicroSeconds) to delay animating to the first crop rect.
// Used only if us_to_first_rect is set and is interpreted as part of the // Used only if us_to_first_rect is set and is interpreted as part of the
// us_to_first_rect time budget. // us_to_first_rect time budget.
optional int64 us_to_first_rect_delay = 16 [default = 0]; optional int64 us_to_first_rect_delay = 16 [default = 0];
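A minimal configuration sketch for the new fields (field names come from the proto above; the helper name and include path are assumptions):

#include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"  // assumed path
#include "mediapipe/framework/calculator.pb.h"

// Start fully zoomed out and animate to the first crop rect over 1 s, holding
// for the first 0.5 s of that budget.
void ConfigureFirstRectAnimation(mediapipe::CalculatorGraphConfig::Node* node) {
  auto* options = node->mutable_options()->MutableExtension(
      mediapipe::autoflip::ContentZoomingCalculatorOptions::ext);
  options->set_start_zoomed_out(true);
  options->set_us_to_first_rect(1000000);       // 1,000,000 us = 1 s
  options->set_us_to_first_rect_delay(500000);  // counted within the 1 s budget
}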

View File

@ -127,6 +127,29 @@ const char kConfigD[] = R"(
} }
)"; )";
const char kConfigE[] = R"(
calculator: "ContentZoomingCalculator"
input_stream: "VIDEO_SIZE:size"
input_stream: "DETECTIONS:detections"
input_stream: "ANIMATE_ZOOM:animate_zoom"
output_stream: "CROP_RECT:rect"
output_stream: "FIRST_CROP_RECT:first_rect"
options: {
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
max_zoom_value_deg: 0
kinematic_options_zoom {
min_motion_to_reframe: 1.2
}
kinematic_options_tilt {
min_motion_to_reframe: 1.2
}
kinematic_options_pan {
min_motion_to_reframe: 1.2
}
}
}
)";
void CheckBorder(const StaticFeatures& static_features, int width, int height, void CheckBorder(const StaticFeatures& static_features, int width, int height,
int top_border, int bottom_border) { int top_border, int bottom_border) {
ASSERT_EQ(2, static_features.border().size()); ASSERT_EQ(2, static_features.border().size());
@ -145,9 +168,14 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
EXPECT_EQ(Border::BOTTOM, part.relative_position()); EXPECT_EQ(Border::BOTTOM, part.relative_position());
} }
struct AddDetectionFlags {
std::optional<bool> animated_zoom;
};
void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time, void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
const int width, const int height, const int width, const int height,
CalculatorRunner* runner) { CalculatorRunner* runner,
const AddDetectionFlags& flags = {}) {
auto detections = std::make_unique<std::vector<mediapipe::Detection>>(); auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
if (position.width > 0 && position.height > 0) { if (position.width > 0 && position.height > 0) {
mediapipe::Detection detection; mediapipe::Detection detection;
@ -175,6 +203,14 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
runner->MutableInputs() runner->MutableInputs()
->Tag("VIDEO_SIZE") ->Tag("VIDEO_SIZE")
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time))); .packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
if (flags.animated_zoom.has_value()) {
runner->MutableInputs()
->Tag("ANIMATE_ZOOM")
.packets.push_back(
mediapipe::MakePacket<bool>(flags.animated_zoom.value())
.At(Timestamp(time)));
}
} }
void AddDetection(const cv::Rect_<float>& position, const int64 time, void AddDetection(const cv::Rect_<float>& position, const int64 time,
@ -703,7 +739,33 @@ TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
CheckCropRect(500, 500, 1000, 1000, 2, CheckCropRect(500, 500, 1000, 1000, 2,
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, StartZoomedOut) { TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_start_zoomed_out(true);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
runner.get());
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get());
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 1000, 1000, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 880, 880, 1,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 760, 760, 2,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 655, 655, 3,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, AnimateToFirstRect) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto* options = config.mutable_options()->MutableExtension( auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext); ContentZoomingCalculatorOptions::ext);
@ -733,6 +795,65 @@ TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
runner->Outputs().Tag("CROP_RECT").packets); runner->Outputs().Tag("CROP_RECT").packets);
} }
TEST(ContentZoomingCalculatorTest, CanControlAnimation) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_start_zoomed_out(true);
options->set_us_to_first_rect(1000000);
options->set_us_to_first_rect_delay(500000);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
// Request the animation for the first frame.
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get(), {.animated_zoom = true});
// We now stop requesting animated zoom and expect the already started
// animation to run to completion. This tests that the zoom-in continues in
// the call when it was started in the Meet greenroom.
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
runner.get(), {.animated_zoom = false});
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
runner.get(), {.animated_zoom = false});
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
runner.get(), {.animated_zoom = false});
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
runner.get(), {.animated_zoom = false});
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 1000, 1000, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 1000, 1000, 1,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 470, 470, 2,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 222, 222, 3,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 222, 222, 4,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, DoesNotAnimateIfDisabledViaInput) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
auto* options = config.mutable_options()->MutableExtension(
ContentZoomingCalculatorOptions::ext);
options->set_start_zoomed_out(true);
options->set_us_to_first_rect(1000000);
options->set_us_to_first_rect_delay(500000);
auto runner = ::absl::make_unique<CalculatorRunner>(config);
// Disable the animation starting with the very first frame.
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
runner.get(), {.animated_zoom = false});
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
runner.get(), {.animated_zoom = false});
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
runner.get(), {.animated_zoom = false});
MP_ASSERT_OK(runner->Run());
CheckCropRect(500, 500, 1000, 1000, 0,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 880, 880, 1,
runner->Outputs().Tag("CROP_RECT").packets);
CheckCropRect(500, 500, 760, 760, 2,
runner->Outputs().Tag("CROP_RECT").packets);
}
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) { TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD); auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
auto runner = ::absl::make_unique<CalculatorRunner>(config); auto runner = ::absl::make_unique<CalculatorRunner>(config);

View File

@ -47,4 +47,10 @@ message FaceBoxAdjusterCalculatorOptions {
// and height respectively. // and height respectively.
optional float ipd_face_box_width_ratio = 6 [default = 0.5566]; optional float ipd_face_box_width_ratio = 6 [default = 0.5566];
optional float ipd_face_box_height_ratio = 7 [default = 0.3131]; optional float ipd_face_box_height_ratio = 7 [default = 0.3131];
// The max look-up angle before considering the eye distance unstable.
optional float max_head_tilt_angle_deg = 8 [default = 12.0];
// The max amount of time to use an old eye distance when the face look angle
// is unstable.
optional int32 max_facesize_history_us = 9 [default = 8000000];
} }

View File

@ -345,8 +345,7 @@ TEST(SceneCroppingCalculatorTest, ChecksPriorFrameBufferSize) {
TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) { TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) {
const CalculatorGraphConfig::Node config = const CalculatorGraphConfig::Node config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(absl::Substitute( ParseTextProtoOrDie<CalculatorGraphConfig::Node>(absl::Substitute(
kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight, kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight));
kTargetSizeType, 0, kPriorFrameBufferSize));
auto runner = absl::make_unique<CalculatorRunner>(config); auto runner = absl::make_unique<CalculatorRunner>(config);
const auto status = runner->Run(); const auto status = runner->Run();
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());

View File

@ -220,7 +220,7 @@ absl::Status KinematicPathSolver::GetTargetPosition(int* target_position) {
absl::Status KinematicPathSolver::UpdatePixelsPerDegree( absl::Status KinematicPathSolver::UpdatePixelsPerDegree(
const float pixels_per_degree) { const float pixels_per_degree) {
RET_CHECK_GT(pixels_per_degree_, 0) RET_CHECK_GT(pixels_per_degree, 0)
<< "pixels_per_degree must be larger than 0."; << "pixels_per_degree must be larger than 0.";
pixels_per_degree_ = pixels_per_degree; pixels_per_degree_ = pixels_per_degree;
return absl::OkStatus(); return absl::OkStatus();

View File

@ -38,7 +38,7 @@ node {
output_stream: "TENSORS:detection_tensors" output_stream: "TENSORS:detection_tensors"
options: { options: {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] { [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "mediapipe/models/face_detection_back.tflite" model_path: "mediapipe/modules/face_detection/face_detection_back.tflite"
} }
} }
} }
@ -111,26 +111,13 @@ node {
} }
} }
# Maps detection label IDs to the corresponding label text ("Face"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
options: {
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
label_map_path: "mediapipe/models/face_detection_back_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the # Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the # letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the # corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation). # input image to the graph before image transformation).
node { node {
calculator: "DetectionLetterboxRemovalCalculator" calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections" input_stream: "DETECTIONS:filtered_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding" input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:output_detections" output_stream: "DETECTIONS:output_detections"
} }

View File

@ -1,34 +0,0 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
cc_binary(
name = "upper_body_pose_tracking_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_cpu_deps",
],
)
# Linux only
cc_binary(
name = "upper_body_pose_tracking_gpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
],
)

View File

@ -61,8 +61,7 @@ objc_library(
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
], ],
deps = [ deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",

View File

@ -63,7 +63,7 @@ objc_library(
data = [ data = [
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
], ],
deps = [ deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",

View File

@ -1,78 +0,0 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_application",
)
load(
"//mediapipe/examples/ios:bundle_id.bzl",
"BUNDLE_ID_PREFIX",
"example_provisioning",
)
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
alias(
name = "upperbodyposetrackinggpu",
actual = "UpperBodyPoseTrackingGpuApp",
)
ios_application(
name = "UpperBodyPoseTrackingGpuApp",
app_icons = ["//mediapipe/examples/ios/common:AppIcon"],
bundle_id = BUNDLE_ID_PREFIX + ".UpperBodyPoseTrackingGpu",
families = [
"iphone",
"ipad",
],
infoplists = [
"//mediapipe/examples/ios/common:Info.plist",
"Info.plist",
],
minimum_os_version = MIN_IOS_VERSION,
provisioning_profile = example_provisioning(),
deps = [
":UpperBodyPoseTrackingGpuAppLibrary",
"@ios_opencv//:OpencvFramework",
],
)
objc_library(
name = "UpperBodyPoseTrackingGpuAppLibrary",
srcs = [
"UpperBodyPoseTrackingViewController.mm",
],
hdrs = [
"UpperBodyPoseTrackingViewController.h",
],
copts = ["-std=c++17"],
data = [
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
"//mediapipe/modules/pose_detection:pose_detection.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
],
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
] + select({
"//mediapipe:ios_i386": [],
"//mediapipe:ios_x86_64": [],
"//conditions:default": [
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
"//mediapipe/framework/formats:landmark_cc_proto",
],
}),
)

View File

@ -1,16 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CameraPosition</key>
<string>back</string>
<key>MainViewController</key>
<string>UpperBodyPoseTrackingViewController</string>
<key>GraphOutputStream</key>
<string>output_video</string>
<key>GraphInputStream</key>
<string>input_video</string>
<key>GraphName</key>
<string>upper_body_pose_tracking_gpu</string>
</dict>
</plist>

View File

@ -1,53 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "UpperBodyPoseTrackingViewController.h"
#include "mediapipe/framework/formats/landmark.pb.h"
static const char* kLandmarksOutputStream = "pose_landmarks";
@implementation UpperBodyPoseTrackingViewController
#pragma mark - UIViewController methods
- (void)viewDidLoad {
[super viewDidLoad];
[self.mediapipeGraph addFrameOutputStream:kLandmarksOutputStream
outputPacketType:MPPPacketTypeRaw];
}
#pragma mark - MPPGraphDelegate methods
// Receives a raw packet from the MediaPipe graph. Invoked on a MediaPipe worker thread.
- (void)mediapipeGraph:(MPPGraph*)graph
didOutputPacket:(const ::mediapipe::Packet&)packet
fromStream:(const std::string&)streamName {
if (streamName == kLandmarksOutputStream) {
if (packet.IsEmpty()) {
NSLog(@"[TS:%lld] No pose landmarks", packet.Timestamp().Value());
return;
}
const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>();
NSLog(@"[TS:%lld] Number of pose landmarks: %d", packet.Timestamp().Value(),
landmarks.landmark_size());
for (int i = 0; i < landmarks.landmark_size(); ++i) {
NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(),
landmarks.landmark(i).y(), landmarks.landmark(i).z());
}
}
}
@end

View File

@ -15,6 +15,7 @@
# #
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"]) licenses(["notice"])
@ -27,10 +28,28 @@ package_group(
], ],
) )
exports_files([ bzl_library(
"transitive_protos.bzl", name = "transitive_protos_bzl",
"encode_binary_proto.bzl", srcs = [
]) "transitive_protos.bzl",
],
visibility = ["//mediapipe/framework:__subpackages__"],
)
bzl_library(
name = "encode_binary_proto_bzl",
srcs = [
"encode_binary_proto.bzl",
],
visibility = ["//visibility:public"],
)
alias(
name = "encode_binary_proto",
actual = ":encode_binary_proto_bzl",
deprecation = "Use encode_binary_proto_bzl",
visibility = ["//visibility:public"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "calculator_proto", name = "calculator_proto",

View File

@ -1,15 +1,8 @@
package( package(
default_visibility = [":preview_users"], default_visibility = ["//visibility:public"],
features = ["-use_header_modules"], features = ["-use_header_modules"],
) )
package_group(
name = "preview_users",
packages = [
"//mediapipe/...",
],
)
licenses(["notice"]) licenses(["notice"])
cc_library( cc_library(

View File

@ -422,6 +422,9 @@ message CalculatorGraphConfig {
// the graph config. // the graph config.
string type = 20; string type = 20;
// Can be used for annotating a graph. // The types and default values for graph options, in proto2 syntax.
MediaPipeOptions options = 1001; MediaPipeOptions options = 1001;
// The types and default values for graph options, in proto3 syntax.
repeated google.protobuf.Any graph_options = 1002;
} }

View File

@ -411,7 +411,8 @@ absl::Status CalculatorGraph::Initialize(
absl::Status CalculatorGraph::ObserveOutputStream( absl::Status CalculatorGraph::ObserveOutputStream(
const std::string& stream_name, const std::string& stream_name,
std::function<absl::Status(const Packet&)> packet_callback) { std::function<absl::Status(const Packet&)> packet_callback,
bool observe_timestamp_bounds) {
RET_CHECK(initialized_).SetNoLogging() RET_CHECK(initialized_).SetNoLogging()
<< "CalculatorGraph is not initialized."; << "CalculatorGraph is not initialized.";
// TODO Allow output observers to be attached by graph level // TODO Allow output observers to be attached by graph level
@ -425,7 +426,7 @@ absl::Status CalculatorGraph::ObserveOutputStream(
auto observer = absl::make_unique<internal::OutputStreamObserver>(); auto observer = absl::make_unique<internal::OutputStreamObserver>();
MP_RETURN_IF_ERROR(observer->Initialize( MP_RETURN_IF_ERROR(observer->Initialize(
stream_name, &any_packet_type_, std::move(packet_callback), stream_name, &any_packet_type_, std::move(packet_callback),
&output_stream_managers_[output_stream_index])); &output_stream_managers_[output_stream_index], observe_timestamp_bounds));
graph_output_streams_.push_back(std::move(observer)); graph_output_streams_.push_back(std::move(observer));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -157,7 +157,8 @@ class CalculatorGraph {
// TODO: Rename to AddOutputStreamCallback. // TODO: Rename to AddOutputStreamCallback.
absl::Status ObserveOutputStream( absl::Status ObserveOutputStream(
const std::string& stream_name, const std::string& stream_name,
std::function<absl::Status(const Packet&)> packet_callback); std::function<absl::Status(const Packet&)> packet_callback,
bool observe_timestamp_bounds = false);
// Adds an OutputStreamPoller for a stream. This provides a synchronous, // Adds an OutputStreamPoller for a stream. This provides a synchronous,
// polling API for accessing a stream's output. Should only be called before // polling API for accessing a stream's output. Should only be called before

View File

@ -1518,5 +1518,72 @@ TEST(CalculatorGraphBoundsTest, OffsetAndBound) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
// A Calculator that sends empty output stream packets.
class EmptyPacketCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) final {
if (cc->InputTimestamp().Value() % 2 == 0) {
cc->Outputs().Index(0).AddPacket(Packet().At(cc->InputTimestamp()));
}
return absl::OkStatus();
}
};
REGISTER_CALCULATOR(EmptyPacketCalculator);
// This test shows that an output timestamp bound can be specified by outputting
// an empty packet with a settled timestamp.
TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
// EmptyPacketCalculator emits empty packets that settle timestamp bounds, and
// ProcessBoundToPacketCalculator converts those bounds back into packets.
std::string config_str = R"(
input_stream: "input_0"
node {
calculator: "EmptyPacketCalculator"
input_stream: "input_0"
output_stream: "empty_0"
}
node {
calculator: "ProcessBoundToPacketCalculator"
input_stream: "empty_0"
output_stream: "output_0"
}
)";
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
std::vector<Packet> output_0_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
output_0_packets.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send in packets.
for (int i = 0; i < 9; ++i) {
const int ts = 10 + i * 10;
Packet p = MakePacket<int>(i).At(Timestamp(ts));
MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
// 9 empty packets are converted to bounds and then to packets.
EXPECT_EQ(output_0_packets.size(), 9);
for (int i = 0; i < 9; ++i) {
EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
}
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -16,11 +16,20 @@
# The dependencies of mediapipe. # The dependencies of mediapipe.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"]) licenses(["notice"])
package(default_visibility = ["//visibility:private"]) package(default_visibility = ["//visibility:private"])
bzl_library(
name = "expand_template_bzl",
srcs = [
"expand_template.bzl",
],
visibility = ["//mediapipe/framework:__subpackages__"],
)
proto_library( proto_library(
name = "proto_descriptor_proto", name = "proto_descriptor_proto",
srcs = ["proto_descriptor.proto"], srcs = ["proto_descriptor.proto"],

View File

@ -295,6 +295,7 @@ cc_library(
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"//mediapipe/framework:port", "//mediapipe/framework:port",
"//mediapipe/framework:type_map",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [

View File

@ -14,6 +14,8 @@
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/type_map.h"
namespace mediapipe { namespace mediapipe {
// TODO Refactor common code from GpuBufferToImageFrameCalculator // TODO Refactor common code from GpuBufferToImageFrameCalculator
@ -67,8 +69,7 @@ bool Image::ConvertToGpu() const {
#else #else
if (use_gpu_) return true; // Already on GPU. if (use_gpu_) return true; // Already on GPU.
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
auto packet = MakePacket<ImageFrame>(std::move(*image_frame_)); auto packet = PointToForeign<ImageFrame>(image_frame_.get());
image_frame_ = nullptr;
CFHolder<CVPixelBufferRef> buffer; CFHolder<CVPixelBufferRef> buffer;
auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer); auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer);
CHECK_OK(status); CHECK_OK(status);
@ -94,4 +95,7 @@ bool Image::ConvertToGpu() const {
#endif // MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_DISABLE_GPU
} }
MEDIAPIPE_REGISTER_TYPE(mediapipe::Image, "::mediapipe::Image", nullptr,
nullptr);
} // namespace mediapipe } // namespace mediapipe

View File

@ -72,8 +72,8 @@ class Image {
// Creates an Image representing the same image content as the ImageFrame // Creates an Image representing the same image content as the ImageFrame
// the input shared pointer points to, and retaining shared ownership. // the input shared pointer points to, and retaining shared ownership.
explicit Image(ImageFrameSharedPtr frame_buffer) explicit Image(ImageFrameSharedPtr image_frame)
: image_frame_(std::move(frame_buffer)) { : image_frame_(std::move(image_frame)) {
use_gpu_ = false; use_gpu_ = false;
pixel_mutex_ = std::make_shared<absl::Mutex>(); pixel_mutex_ = std::make_shared<absl::Mutex>();
} }

View File

@ -30,6 +30,9 @@
namespace mediapipe { namespace mediapipe {
// Zero and negative values are not checked here.
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
int BhwcBatchFromShape(const Tensor::Shape& shape) { int BhwcBatchFromShape(const Tensor::Shape& shape) {
LOG_IF(FATAL, shape.dims.empty()) LOG_IF(FATAL, shape.dims.empty())
<< "Tensor::Shape must be non-empty to retrieve a named dimension"; << "Tensor::Shape must be non-empty to retrieve a named dimension";
@ -237,6 +240,12 @@ void Tensor::AllocateOpenGlTexture2d() const {
glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_, glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_,
texture_height_); texture_height_);
} else { } else {
// GLES 2.0 supports only the clamp addressing mode for NPOT textures.
// If either dimension is NPOT, both addressing modes must be set to clamp.
if (!IsPowerOfTwo(texture_width_) || !IsPowerOfTwo(texture_height_)) {
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
}
// We assume all contexts will have the same extensions, so we only check // We assume all contexts will have the same extensions, so we only check
// once for OES_texture_float extension, to save time. // once for OES_texture_float extension, to save time.
static bool has_oes_extension = static bool has_oes_extension =
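The power-of-two check added above uses the usual bit trick, which (as its comment notes) is only meaningful for positive values; a tiny self-check sketch:

// (v & (v - 1)) == 0 holds exactly for positive powers of two.
static_assert((256 & (256 - 1)) == 0, "256 is a power of two");
static_assert((320 & (320 - 1)) != 0, "320 (a typical NPOT width) is not");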

View File

@ -14,13 +14,16 @@
#include "mediapipe/framework/graph_output_stream.h" #include "mediapipe/framework/graph_output_stream.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe { namespace mediapipe {
namespace internal { namespace internal {
absl::Status GraphOutputStream::Initialize( absl::Status GraphOutputStream::Initialize(
const std::string& stream_name, const PacketType* packet_type, const std::string& stream_name, const PacketType* packet_type,
OutputStreamManager* output_stream_manager) { OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
RET_CHECK(output_stream_manager); RET_CHECK(output_stream_manager);
// Initializes input_stream_handler_ with one input stream as the observer. // Initializes input_stream_handler_ with one input stream as the observer.
@ -31,6 +34,7 @@ absl::Status GraphOutputStream::Initialize(
input_stream_handler_ = absl::make_unique<GraphOutputStreamHandler>( input_stream_handler_ = absl::make_unique<GraphOutputStreamHandler>(
tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(), tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(),
/*calculator_run_in_parallel=*/false); /*calculator_run_in_parallel=*/false);
input_stream_handler_->SetProcessTimestampBounds(observe_timestamp_bounds);
const CollectionItemId& id = tag_map->BeginId(); const CollectionItemId& id = tag_map->BeginId();
input_stream_ = absl::make_unique<InputStreamManager>(); input_stream_ = absl::make_unique<InputStreamManager>();
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
@ -52,20 +56,58 @@ void GraphOutputStream::PrepareForRun(
absl::Status OutputStreamObserver::Initialize( absl::Status OutputStreamObserver::Initialize(
const std::string& stream_name, const PacketType* packet_type, const std::string& stream_name, const PacketType* packet_type,
std::function<absl::Status(const Packet&)> packet_callback, std::function<absl::Status(const Packet&)> packet_callback,
OutputStreamManager* output_stream_manager) { OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
RET_CHECK(output_stream_manager); RET_CHECK(output_stream_manager);
packet_callback_ = std::move(packet_callback); packet_callback_ = std::move(packet_callback);
observe_timestamp_bounds_ = observe_timestamp_bounds;
return GraphOutputStream::Initialize(stream_name, packet_type, return GraphOutputStream::Initialize(stream_name, packet_type,
output_stream_manager); output_stream_manager,
observe_timestamp_bounds);
} }
absl::Status OutputStreamObserver::Notify() { absl::Status OutputStreamObserver::Notify() {
// Lets one thread perform packet notification as much as possible.
// Other threads should quit if a thread is already performing notification.
{
absl::MutexLock l(&mutex_);
if (notifying_ == false) {
notifying_ = true;
} else {
return absl::OkStatus();
}
}
while (true) { while (true) {
bool empty; bool empty;
Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty); Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty);
if (empty) { if (empty) {
break; // Emits an empty packet at timestamp_bound.PreviousAllowedInStream().
if (observe_timestamp_bounds_ && min_timestamp < Timestamp::Done()) {
Timestamp settled = (min_timestamp == Timestamp::PostStream()
? Timestamp::PostStream()
: min_timestamp.PreviousAllowedInStream());
if (last_processed_ts_ < settled) {
MP_RETURN_IF_ERROR(packet_callback_(Packet().At(settled)));
last_processed_ts_ = settled;
}
}
// Final check to make sure that the min timestamp or bound hasn't changed.
// If it is unchanged, flip notifying_ to false so that other threads can
// perform notification when new packets/timestamp bounds arrive. Otherwise,
// if the min timestamp or bound has been updated, jump back to the beginning
// of the notification loop for another iteration.
{
absl::MutexLock l(&mutex_);
Timestamp new_min_timestamp =
input_stream_->MinTimestampOrBound(&empty);
if (new_min_timestamp == min_timestamp) {
notifying_ = false;
break;
} else {
continue;
}
}
} }
int num_packets_dropped = 0; int num_packets_dropped = 0;
bool stream_is_done = false; bool stream_is_done = false;
@ -75,6 +117,7 @@ absl::Status OutputStreamObserver::Notify() {
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, input_stream_->Name()); num_packets_dropped, input_stream_->Name());
MP_RETURN_IF_ERROR(packet_callback_(packet)); MP_RETURN_IF_ERROR(packet_callback_(packet));
last_processed_ts_ = min_timestamp;
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -52,7 +52,8 @@ class GraphOutputStream {
// is not transferred to the graph output stream object. // is not transferred to the graph output stream object.
absl::Status Initialize(const std::string& stream_name, absl::Status Initialize(const std::string& stream_name,
const PacketType* packet_type, const PacketType* packet_type,
OutputStreamManager* output_stream_manager); OutputStreamManager* output_stream_manager,
bool observe_timestamp_bounds = false);
// Installs callbacks into its GraphOutputStreamHandler. // Installs callbacks into its GraphOutputStreamHandler.
virtual void PrepareForRun(std::function<void()> notification_callback, virtual void PrepareForRun(std::function<void()> notification_callback,
@ -99,6 +100,10 @@ class GraphOutputStream {
} }
}; };
bool observe_timestamp_bounds_;
absl::Mutex mutex_;
bool notifying_ ABSL_GUARDED_BY(mutex_) = false;
Timestamp last_processed_ts_ = Timestamp::Unstarted();
std::unique_ptr<InputStreamHandler> input_stream_handler_; std::unique_ptr<InputStreamHandler> input_stream_handler_;
std::unique_ptr<InputStreamManager> input_stream_; std::unique_ptr<InputStreamManager> input_stream_;
}; };
@ -112,7 +117,8 @@ class OutputStreamObserver : public GraphOutputStream {
absl::Status Initialize( absl::Status Initialize(
const std::string& stream_name, const PacketType* packet_type, const std::string& stream_name, const PacketType* packet_type,
std::function<absl::Status(const Packet&)> packet_callback, std::function<absl::Status(const Packet&)> packet_callback,
OutputStreamManager* output_stream_manager); OutputStreamManager* output_stream_manager,
bool observe_timestamp_bounds = false);
// Notifies the observer of new packets emitted by the observed // Notifies the observer of new packets emitted by the observed
// output stream. // output stream.
@ -128,6 +134,7 @@ class OutputStreamObserver : public GraphOutputStream {
// OutputStreamPollerImpl that returns packets to the caller via // OutputStreamPollerImpl that returns packets to the caller via
// Next()/NextBatch(). // Next()/NextBatch().
// TODO: Support observe_timestamp_bounds.
class OutputStreamPollerImpl : public GraphOutputStream { class OutputStreamPollerImpl : public GraphOutputStream {
public: public:
virtual ~OutputStreamPollerImpl() {} virtual ~OutputStreamPollerImpl() {}

View File

@ -20,6 +20,9 @@ syntax = "proto2";
package mediapipe; package mediapipe;
option java_package = "com.google.mediapipe.proto";
option java_outer_classname = "MediaPipeOptionsProto";
// Options used by a MediaPipe object. // Options used by a MediaPipe object.
message MediaPipeOptions { message MediaPipeOptions {
extensions 20000 to max; extensions 20000 to max;

View File

@ -101,8 +101,8 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) {
} }
if (packet.IsEmpty()) { if (packet.IsEmpty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) SetNextTimestampBound(packet.Timestamp().NextAllowedInStream());
<< "Empty packet sent to stream \"" << Name() << "\"."; return absl::OkStatus();
} }
const Timestamp timestamp = packet.Timestamp(); const Timestamp timestamp = packet.Timestamp();

View File

@ -20,6 +20,7 @@ load(
"mediapipe_binary_graph", "mediapipe_binary_graph",
) )
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"]) licenses(["notice"])
@ -29,6 +30,30 @@ exports_files([
"simple_subgraph_template.cc", "simple_subgraph_template.cc",
]) ])
bzl_library(
name = "mediapipe_graph_bzl",
srcs = [
"mediapipe_graph.bzl",
],
visibility = ["//visibility:public"],
deps = [
":build_defs_bzl",
"//mediapipe/framework:encode_binary_proto",
"//mediapipe/framework:transitive_protos_bzl",
"//mediapipe/framework/deps:expand_template_bzl",
],
)
bzl_library(
name = "build_defs_bzl",
srcs = [
"build_defs.bzl",
],
visibility = [
"//mediapipe/framework:__subpackages__",
],
)
cc_library( cc_library(
name = "text_to_binary_graph", name = "text_to_binary_graph",
srcs = ["text_to_binary_graph.cc"], srcs = ["text_to_binary_graph.cc"],
@ -744,5 +769,7 @@ cc_test(
exports_files( exports_files(
["build_defs.bzl"], ["build_defs.bzl"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = [
"//mediapipe/framework:__subpackages__",
],
) )

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/sink.h"
#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
@ -168,8 +169,19 @@ void AddMultiStreamCallback(
std::function<void(const std::vector<Packet>&)> callback, std::function<void(const std::vector<Packet>&)> callback,
CalculatorGraphConfig* config, CalculatorGraphConfig* config,
std::pair<std::string, Packet>* side_packet) { std::pair<std::string, Packet>* side_packet) {
std::map<std::string, Packet> side_packets;
AddMultiStreamCallback(streams, callback, config, &side_packets,
/*observe_timestamp_bounds=*/false);
*side_packet = *side_packets.begin();
}
void AddMultiStreamCallback(
const std::vector<std::string>& streams,
std::function<void(const std::vector<Packet>&)> callback,
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
bool observe_timestamp_bounds) {
CHECK(config); CHECK(config);
CHECK(side_packet); CHECK(side_packets);
CalculatorGraphConfig::Node* sink_node = config->add_node(); CalculatorGraphConfig::Node* sink_node = config->add_node();
const std::string name = GetUnusedNodeName( const std::string name = GetUnusedNodeName(
*config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_"))); *config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_")));
@ -179,15 +191,23 @@ void AddMultiStreamCallback(
sink_node->add_input_stream(stream_name); sink_node->add_input_stream(stream_name);
} }
if (observe_timestamp_bounds) {
const std::string observe_ts_bounds_packet_name = GetUnusedSidePacketName(
*config, absl::StrCat(name, "_observe_ts_bounds"));
sink_node->add_input_side_packet(absl::StrCat(
"OBSERVE_TIMESTAMP_BOUNDS:", observe_ts_bounds_packet_name));
InsertIfNotPresent(side_packets, observe_ts_bounds_packet_name,
MakePacket<bool>(true));
}
const std::string input_side_packet_name = const std::string input_side_packet_name =
GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback")); GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback"));
side_packet->first = input_side_packet_name;
sink_node->add_input_side_packet( sink_node->add_input_side_packet(
absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name)); absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name));
side_packet->second = InsertIfNotPresent(
side_packets, input_side_packet_name,
MakePacket<std::function<void(const std::vector<Packet>&)>>( MakePacket<std::function<void(const std::vector<Packet>&)>>(
std::move(callback)); std::move(callback)));
} }
void AddCallbackWithHeaderCalculator(const std::string& stream_name, void AddCallbackWithHeaderCalculator(const std::string& stream_name,
@ -240,6 +260,10 @@ absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "InputSidePackets must use tags."; << "InputSidePackets must use tags.";
} }
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS")) {
cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Set<bool>();
cc->SetProcessTimestampBounds(true);
}
int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1; int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
@ -266,6 +290,12 @@ absl::Status CallbackCalculator::Open(CalculatorContext* cc) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "missing callback."; << "missing callback.";
} }
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS") &&
!cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Get<bool>()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "The value of the OBSERVE_TIMESTAMP_BOUNDS input side packet "
"must be set to true";
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -115,6 +115,12 @@ void AddMultiStreamCallback(
std::function<void(const std::vector<Packet>&)> callback, std::function<void(const std::vector<Packet>&)> callback,
CalculatorGraphConfig* config, std::pair<std::string, Packet>* side_packet); CalculatorGraphConfig* config, std::pair<std::string, Packet>* side_packet);
void AddMultiStreamCallback(
const std::vector<std::string>& streams,
std::function<void(const std::vector<Packet>&)> callback,
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
bool observe_timestamp_bounds = false);
// Add a CallbackWithHeaderCalculator to intercept packets sent on // Add a CallbackWithHeaderCalculator to intercept packets sent on
// stream stream_name, and the header packet on stream stream_header. // stream stream_name, and the header packet on stream stream_header.
// The input side packet with the produced name callback_side_packet_name // The input side packet with the produced name callback_side_packet_name
View File
@ -146,5 +146,63 @@ TEST(CallbackTest, TestAddMultiStreamCallback) {
EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9)); EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9));
} }
class TimestampBoundTestCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
cc->Outputs().Index(1).Set<int>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) final {
if (count_ % 5 == 0) {
cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_ + 1));
cc->Outputs().Index(1).SetNextTimestampBound(Timestamp(count_ + 1));
}
++count_;
if (count_ == 13) {
return tool::StatusStop();
}
return absl::OkStatus();
}
private:
int count_ = 1;
};
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) {
std::string config_str = R"(
node {
calculator: "TimestampBoundTestCalculator"
output_stream: "foo"
output_stream: "bar"
}
)";
CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
std::vector<int> sums;
std::map<std::string, Packet> side_packets;
tool::AddMultiStreamCallback(
{"foo", "bar"},
[&sums](const std::vector<Packet>& packets) {
Packet foo_p = packets[0];
Packet bar_p = packets[1];
ASSERT_TRUE(foo_p.IsEmpty() && bar_p.IsEmpty());
int foo = foo_p.Timestamp().Value();
int bar = bar_p.Timestamp().Value();
sums.push_back(foo + bar);
},
&graph_config, &side_packets, true);
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun(side_packets));
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(sums, testing::ElementsAre(10, 20));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe
View File
@ -14,7 +14,7 @@
load("//mediapipe/gpu:metal.bzl", "metal_library") load("//mediapipe/gpu:metal.bzl", "metal_library")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
licenses(["notice"]) licenses(["notice"])
@ -240,6 +240,12 @@ cc_library(
], ],
) )
mediapipe_proto_library(
name = "gpu_origin_proto",
srcs = ["gpu_origin.proto"],
visibility = ["//visibility:public"],
)
objc_library( objc_library(
name = "pixel_buffer_pool_util", name = "pixel_buffer_pool_util",
srcs = ["pixel_buffer_pool_util.mm"], srcs = ["pixel_buffer_pool_util.mm"],
@ -460,6 +466,8 @@ cc_library(
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_node",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/util:resource_cache",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
] + select({ ] + select({
@ -760,8 +768,10 @@ cc_library(
deps = [ deps = [
":gl_calculator_helper", ":gl_calculator_helper",
":gl_quad_renderer", ":gl_quad_renderer",
":gpu_buffer",
":shader_util", ":shader_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
View File
@ -563,8 +563,14 @@ class GlFenceSyncPoint : public GlSyncPoint {
void WaitOnGpu() override { void WaitOnGpu() override {
if (!sync_) return; if (!sync_) return;
// TODO: do not wait if we are already on the same context? // TODO: do not wait if we are already on the same context?
// WebGL2 specifies a waitSync call, but since cross-context
// synchronization is not supported, it's actually a no-op. Firefox prints
// a warning when it's called, so let's just skip the call. See
// b/184637485 for details.
#ifndef __EMSCRIPTEN__
glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED); glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED);
#endif
} }
bool IsReady() override { bool IsReady() override {
Some files were not shown because too many files have changed in this diff.