Project import generated by Copybara.
GitOrigin-RevId: ff83882955f1a1e2a043ff4e71278be9d7217bbe
|
@ -23,6 +23,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
|||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-8 g++-8 \
|
||||
ca-certificates \
|
||||
curl \
|
||||
ffmpeg \
|
||||
|
@ -44,6 +45,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /usr/bin/g++ g++ /usr/bin/g++-8
|
||||
RUN pip3 install --upgrade setuptools
|
||||
RUN pip3 install wheel
|
||||
RUN pip3 install future
|
||||
|
|
|
@ -337,6 +337,8 @@ maven_install(
|
|||
"androidx.test.espresso:espresso-core:3.1.1",
|
||||
"com.github.bumptech.glide:glide:4.11.0",
|
||||
"com.google.android.material:material:aar:1.0.0-rc01",
|
||||
"com.google.auto.value:auto-value:1.6.4",
|
||||
"com.google.auto.value:auto-value-annotations:1.6.4",
|
||||
"com.google.code.findbugs:jsr305:3.0.2",
|
||||
"com.google.flogger:flogger-system-backend:0.3.1",
|
||||
"com.google.flogger:flogger:0.3.1",
|
||||
|
@ -367,9 +369,9 @@ http_archive(
|
|||
)
|
||||
|
||||
# Tensorflow repo should always go after the other external dependencies.
|
||||
# 2021-03-25
|
||||
_TENSORFLOW_GIT_COMMIT = "c67f68021824410ebe9f18513b8856ac1c6d4887"
|
||||
_TENSORFLOW_SHA256= "fd07d0b39422dc435e268c5e53b2646a8b4b1e3151b87837b43f86068faae87f"
|
||||
# 2021-04-30
|
||||
_TENSORFLOW_GIT_COMMIT = "5bd3c57ef184543d22e34e36cff9d9bea608e06d"
|
||||
_TENSORFLOW_SHA256= "9a45862834221aafacf6fb275f92b3876bc89443cbecc51be93f13839a6609f0"
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
urls = [
|
||||
|
|
|
@ -17,15 +17,15 @@
|
|||
# Script to build/run all MediaPipe desktop example apps (with webcam input).
|
||||
#
|
||||
# To build and run all apps and store them in out_dir:
|
||||
# $ ./build_ios_examples.sh -d out_dir
|
||||
# $ ./build_desktop_examples.sh -d out_dir
|
||||
# Omitting -d and the associated directory saves all generated apps in the
|
||||
# current directory.
|
||||
# To build all apps and store them in out_dir:
|
||||
# $ ./build_ios_examples.sh -d out_dir -b
|
||||
# $ ./build_desktop_examples.sh -d out_dir -b
|
||||
# Omitting -d and the associated directory saves all generated apps in the
|
||||
# current directory.
|
||||
# To run all apps already stored in out_dir:
|
||||
# $ ./build_ios_examples.sh -d out_dir -r
|
||||
# $ ./build_desktop_examples.sh -d out_dir -r
|
||||
# Omitting -d and the associated directory assumes all apps are in the current
|
||||
# directory.
|
||||
|
||||
|
|
|
@ -187,7 +187,7 @@ node {
|
|||
```
|
||||
|
||||
In the calculator implementation, inputs and outputs are also identified by tag
|
||||
name and index number. In the function below input are output are identified:
|
||||
name and index number. In the function below input and output are identified:
|
||||
|
||||
* By index number: The combined input stream is identified simply by index
|
||||
`0`.
|
||||
|
@ -355,7 +355,6 @@ class PacketClonerCalculator : public CalculatorBase {
|
|||
current_[i].At(cc->InputTimestamp()));
|
||||
// Add a packet to output stream of index i a packet from inputstream i
|
||||
// with timestamp common to all present inputs
|
||||
//
|
||||
} else {
|
||||
cc->Outputs().Index(i).SetNextTimestampBound(
|
||||
cc->InputTimestamp().NextAllowedInStream());
|
||||
|
@ -382,7 +381,7 @@ defined your calculator class, register it with a macro invocation
|
|||
REGISTER_CALCULATOR(calculator_class_name).
|
||||
|
||||
Below is a trivial MediaPipe graph that has 3 input streams, 1 node
|
||||
(PacketClonerCalculator) and 3 output streams.
|
||||
(PacketClonerCalculator) and 2 output streams.
|
||||
|
||||
```proto
|
||||
input_stream: "room_mic_signal"
|
||||
|
|
|
@ -83,12 +83,12 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
|
|||
output_stream: "out3"
|
||||
|
||||
node {
|
||||
calculator: "PassThroughculator"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "out1"
|
||||
output_stream: "out2"
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughculator"
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "out2"
|
||||
output_stream: "out3"
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ Please verify all the necessary packages are installed.
|
|||
* Android SDK Build-Tools 28 or 29
|
||||
* Android SDK Platform-Tools 28 or 29
|
||||
* Android SDK Tools 26.1.1
|
||||
* Android NDK 17c or above
|
||||
* Android NDK 19c or above
|
||||
|
||||
### Option 1: Build with Bazel in Command Line
|
||||
|
||||
|
@ -111,7 +111,7 @@ app:
|
|||
* Verify that Android SDK Build-Tools 28 or 29 is installed.
|
||||
* Verify that Android SDK Platform-Tools 28 or 29 is installed.
|
||||
* Verify that Android SDK Tools 26.1.1 is installed.
|
||||
* Verify that Android NDK 17c or above is installed.
|
||||
* Verify that Android NDK 19c or above is installed.
|
||||
* Take note of the Android NDK Location, e.g.,
|
||||
`/usr/local/home/Android/Sdk/ndk-bundle` or
|
||||
`/usr/local/home/Android/Sdk/ndk/20.0.5594570`.
|
||||
|
|
|
@ -37,7 +37,7 @@ each project.
|
|||
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
|
||||
|
||||
mediapipe_aar(
|
||||
name = "mp_face_detection_aar",
|
||||
name = "mediapipe_face_detection",
|
||||
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
|
||||
)
|
||||
```
|
||||
|
@ -45,26 +45,29 @@ each project.
|
|||
2. Run the Bazel build command to generate the AAR.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \
|
||||
//path/to/the/aar/build/file:aar_name
|
||||
bazel build -c opt --strip=ALWAYS \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||
//path/to/the/aar/build/file:aar_name.aar
|
||||
```
|
||||
|
||||
For the face detection AAR target we made in the step 1, run:
|
||||
For the face detection AAR target we made in step 1, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar
|
||||
bazel build -c opt --strip=ALWAYS \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar
|
||||
|
||||
# It should print:
|
||||
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
|
||||
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar up-to-date:
|
||||
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||
```
|
||||
|
||||
3. (Optional) Save the AAR to your preferred location.
|
||||
|
||||
```bash
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||
/absolute/path/to/your/preferred/location
|
||||
```
|
||||
|
||||
|
@ -75,7 +78,7 @@ each project.
|
|||
2. Copy the AAR into app/libs.
|
||||
|
||||
```bash
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mediapipe_face_detection.aar
|
||||
/path/to/your/app/libs/
|
||||
```
|
||||
|
||||
|
@ -92,29 +95,14 @@ each project.
|
|||
[the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite).
|
||||
|
||||
```bash
|
||||
bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph
|
||||
cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/
|
||||
bazel build -c opt mediapipe/graphs/face_detection:face_detection_mobile_gpu_binary_graph
|
||||
cp bazel-bin/mediapipe/graphs/face_detection/face_detection_mobile_gpu.binarypb /path/to/your/app/src/main/assets/
|
||||
cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/
|
||||
```
|
||||
|
||||
![Screenshot](../images/mobile/assets_location.png)
|
||||
|
||||
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into
|
||||
app/src/main/jniLibs.
|
||||
|
||||
MediaPipe depends on OpenCV, you will need to copy the precompiled OpenCV so
|
||||
files into app/src/main/jniLibs. You can download the official OpenCV
|
||||
Android SDK from
|
||||
[here](https://github.com/opencv/opencv/releases/download/3.4.3/opencv-3.4.3-android-sdk.zip)
|
||||
and run:
|
||||
|
||||
```bash
|
||||
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
|
||||
```
|
||||
|
||||
![Screenshot](../images/mobile/android_studio_opencv_location.png)
|
||||
|
||||
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
|
||||
4. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
|
||||
|
||||
```
|
||||
dependencies {
|
||||
|
@ -136,10 +124,14 @@ each project.
|
|||
implementation "androidx.camera:camera-core:$camerax_version"
|
||||
implementation "androidx.camera:camera-camera2:$camerax_version"
|
||||
implementation "androidx.camera:camera-lifecycle:$camerax_version"
|
||||
// AutoValue
|
||||
def auto_value_version = "1.6.4"
|
||||
implementation "com.google.auto.value:auto-value-annotations:$auto_value_version"
|
||||
annotationProcessor "com.google.auto.value:auto-value:$auto_value_version"
|
||||
}
|
||||
```
|
||||
|
||||
6. Follow our Android app examples to use MediaPipe in Android Studio for your
|
||||
5. Follow our Android app examples to use MediaPipe in Android Studio for your
|
||||
use case. If you are looking for an example, a face detection example can be
|
||||
found
|
||||
[here](https://github.com/jiuqiant/mediapipe_face_detection_aar_example) and
|
||||
|
|
|
@ -471,7 +471,7 @@ next section.
|
|||
4. Install Visual C++ Build Tools 2019 and WinSDK
|
||||
|
||||
Go to
|
||||
[the VisualStudio website](ttps://visualstudio.microsoft.com/visual-cpp-build-tools),
|
||||
[the VisualStudio website](https://visualstudio.microsoft.com/visual-cpp-build-tools),
|
||||
download build tools, and install Microsoft Visual C++ 2019 Redistributable
|
||||
and Microsoft Build Tools 2019.
|
||||
|
||||
|
@ -738,7 +738,7 @@ common build issues.
|
|||
root@bca08b91ff63:/mediapipe# bash ./setup_android_sdk_and_ndk.sh
|
||||
|
||||
# Should print:
|
||||
# Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r18b
|
||||
# Android NDK is now installed. Consider setting $ANDROID_NDK_HOME environment variable to be /root/Android/Sdk/ndk-bundle/android-ndk-r19c
|
||||
# Set android_ndk_repository and android_sdk_repository in WORKSPACE
|
||||
# Done
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ You can, for instance, activate a Python virtual environment:
|
|||
$ python3 -m venv mp_env && source mp_env/bin/activate
|
||||
```
|
||||
|
||||
Install MediaPipe Python package and start Python intepreter:
|
||||
Install MediaPipe Python package and start Python interpreter:
|
||||
|
||||
```bash
|
||||
(mp_env)$ pip install mediapipe
|
||||
|
|
|
@ -97,6 +97,49 @@ linux_opencv/macos_opencv/windows_opencv.BUILD files for your local opencv
|
|||
libraries. [This GitHub issue](https://github.com/google/mediapipe/issues/666)
|
||||
may also help.
|
||||
|
||||
## Python pip install failure
|
||||
|
||||
The error message:
|
||||
|
||||
```
|
||||
ERROR: Could not find a version that satisfies the requirement mediapipe
|
||||
ERROR: No matching distribution found for mediapipe
|
||||
```
|
||||
|
||||
after running `pip install mediapipe` usually indicates that there is no qualified MediaPipe Python for your system.
|
||||
Please note that MediaPipe Python PyPI officially supports the **64-bit**
|
||||
version of Python 3.7 and above on the following OS:
|
||||
|
||||
- x86_64 Linux
|
||||
- x86_64 macOS 10.15+
|
||||
- amd64 Windows
|
||||
|
||||
If the OS is currently supported and you still see this error, please make sure
|
||||
that both the Python and pip binary are for Python 3.7 and above. Otherwise,
|
||||
please consider building the MediaPipe Python package locally by following the
|
||||
instructions [here](python.md#building-mediapipe-python-package).
|
||||
|
||||
## Python DLL load failure on Windows
|
||||
|
||||
The error message:
|
||||
|
||||
```
|
||||
ImportError: DLL load failed: The specified module could not be found
|
||||
```
|
||||
|
||||
usually indicates that the local Windows system is missing Visual C++
|
||||
redistributable packages and/or Visual C++ runtime DLLs. This can be solved by
|
||||
either installing the official
|
||||
[vc_redist.x64.exe](https://support.microsoft.com/en-us/topic/the-latest-supported-visual-c-downloads-2647da03-1eea-4433-9aff-95f26a218cc0)
|
||||
or installing the "msvc-runtime" Python package by running
|
||||
|
||||
```bash
|
||||
$ python -m pip install msvc-runtime
|
||||
```
|
||||
|
||||
Please note that the "msvc-runtime" Python package is not released or maintained
|
||||
by Microsoft.
|
||||
|
||||
## Native method not found
|
||||
|
||||
The error message:
|
||||
|
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 75 KiB |
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 42 KiB |
BIN
docs/images/mobile/pose_tracking_example.gif
Normal file
After Width: | Height: | Size: 2.3 MiB |
Before Width: | Height: | Size: 6.9 MiB |
|
@ -77,7 +77,7 @@ Supported configuration options:
|
|||
```python
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
mp_face_detction = mp.solutions.face_detection
|
||||
mp_face_detection = mp.solutions.face_detection
|
||||
mp_drawing = mp.solutions.drawing_utils
|
||||
|
||||
# For static images:
|
||||
|
|
|
@ -135,12 +135,11 @@ another detection until it loses track, on reducing computation and latency. If
|
|||
set to `true`, person detection runs every input image, ideal for processing a
|
||||
batch of static, possibly unrelated, images. Default to `false`.
|
||||
|
||||
#### upper_body_only
|
||||
#### model_complexity
|
||||
|
||||
If set to `true`, the solution outputs only the 25 upper-body pose landmarks
|
||||
(535 in total) instead of the full set of 33 pose landmarks (543 in total). Note
|
||||
that upper-body-only prediction may be more accurate for use cases where the
|
||||
lower-body parts are mostly out of view. Default to `false`.
|
||||
Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
|
||||
well as inference latency generally go up with the model complexity. Default to
|
||||
`1`.
|
||||
|
||||
#### smooth_landmarks
|
||||
|
||||
|
@ -207,7 +206,7 @@ install MediaPipe Python package, then learn more in the companion
|
|||
Supported configuration options:
|
||||
|
||||
* [static_image_mode](#static_image_mode)
|
||||
* [upper_body_only](#upper_body_only)
|
||||
* [model_complexity](#model_complexity)
|
||||
* [smooth_landmarks](#smooth_landmarks)
|
||||
* [min_detection_confidence](#min_detection_confidence)
|
||||
* [min_tracking_confidence](#min_tracking_confidence)
|
||||
|
@ -219,7 +218,9 @@ mp_drawing = mp.solutions.drawing_utils
|
|||
mp_holistic = mp.solutions.holistic
|
||||
|
||||
# For static images:
|
||||
with mp_holistic.Holistic(static_image_mode=True) as holistic:
|
||||
with mp_holistic.Holistic(
|
||||
static_image_mode=True,
|
||||
model_complexity=2) as holistic:
|
||||
for idx, file in enumerate(file_list):
|
||||
image = cv2.imread(file)
|
||||
image_height, image_width, _ = image.shape
|
||||
|
@ -240,8 +241,6 @@ with mp_holistic.Holistic(static_image_mode=True) as holistic:
|
|||
annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
||||
mp_drawing.draw_landmarks(
|
||||
annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
|
||||
# Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when
|
||||
# upper_body_only is set to True.
|
||||
mp_drawing.draw_landmarks(
|
||||
annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
|
||||
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
||||
|
@ -291,7 +290,7 @@ and the following usage example.
|
|||
|
||||
Supported configuration options:
|
||||
|
||||
* [upperBodyOnly](#upper_body_only)
|
||||
* [modelComplexity](#model_complexity)
|
||||
* [smoothLandmarks](#smooth_landmarks)
|
||||
* [minDetectionConfidence](#min_detection_confidence)
|
||||
* [minTrackingConfidence](#min_tracking_confidence)
|
||||
|
@ -348,7 +347,7 @@ const holistic = new Holistic({locateFile: (file) => {
|
|||
return `https://cdn.jsdelivr.net/npm/@mediapipe/holistic/${file}`;
|
||||
}});
|
||||
holistic.setOptions({
|
||||
upperBodyOnly: false,
|
||||
modelComplexity: 1,
|
||||
smoothLandmarks: true,
|
||||
minDetectionConfidence: 0.5,
|
||||
minTrackingConfidence: 0.5
|
||||
|
|
|
@ -15,10 +15,10 @@ nav_order: 30
|
|||
### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection)
|
||||
|
||||
* Face detection model for front-facing/selfie camera:
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite),
|
||||
[TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/face-detector-quantized_edgetpu.tflite)
|
||||
* Face detection model for back-facing camera:
|
||||
[TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_back.tflite)
|
||||
[TFLite model ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_back.tflite)
|
||||
* [Model card](https://mediapipe.page.link/blazeface-mc)
|
||||
|
||||
### [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh)
|
||||
|
@ -49,10 +49,10 @@ nav_order: 30
|
|||
|
||||
* Pose detection model:
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_detection/pose_detection.tflite)
|
||||
* Full-body pose landmark model:
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite)
|
||||
* Upper-body pose landmark model:
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_upper_body.tflite)
|
||||
* Pose landmark model:
|
||||
[TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_lite.tflite),
|
||||
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_full.tflite),
|
||||
[TFLite model (heavy)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/pose_landmark/pose_landmark_heavy.tflite)
|
||||
* [Model card](https://mediapipe.page.link/blazepose-mc)
|
||||
|
||||
### [Holistic](https://google.github.io/mediapipe/solutions/holistic)
|
||||
|
|
|
@ -30,8 +30,7 @@ overlay of digital content and information on top of the physical world in
|
|||
augmented reality.
|
||||
|
||||
MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring
|
||||
33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video
|
||||
frames utilizing our
|
||||
33 3D landmarks on the whole body from RGB video frames utilizing our
|
||||
[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html)
|
||||
research that also powers the
|
||||
[ML Kit Pose Detection API](https://developers.google.com/ml-kit/vision/pose-detection).
|
||||
|
@ -40,9 +39,9 @@ environments for inference, whereas our method achieves real-time performance on
|
|||
most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in
|
||||
[python](#python-solution-api) and even on the [web](#javascript-solution-api).
|
||||
|
||||
![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) |
|
||||
:--------------------------------------------------------------------------------------------: |
|
||||
*Fig 1. Example of MediaPipe Pose for upper-body pose tracking.* |
|
||||
![pose_tracking_example.gif](../images/mobile/pose_tracking_example.gif) |
|
||||
:----------------------------------------------------------------------: |
|
||||
*Fig 1. Example of MediaPipe Pose for pose tracking.* |
|
||||
|
||||
## ML Pipeline
|
||||
|
||||
|
@ -77,6 +76,23 @@ Note: To visualize a graph, copy the graph and paste it into
|
|||
to visualize its associated subgraphs, please see
|
||||
[visualizer documentation](../tools/visualizer.md).
|
||||
|
||||
## Pose Estimation Quality
|
||||
|
||||
To evaluate the quality of our [models](./models.md#pose) against other
|
||||
well-performing publicly available solutions, we use a validation dataset,
|
||||
consisting of 1k images with diverse Yoga, HIIT, and Dance postures. Each image
|
||||
contains only a single person located 2-4 meters from the camera. To be
|
||||
consistent with other solutions, we perform evaluation only for 17 keypoints
|
||||
from [COCO topology](https://cocodataset.org/#keypoints-2020).
|
||||
|
||||
Method | [mAP](https://cocodataset.org/#keypoints-eval) | [PCK@0.2](https://github.com/cbsudux/Human-Pose-Estimation-101) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | [FPS](https://en.wikipedia.org/wiki/Frame_rate), MacBook Pro (15-inch, 2017)
|
||||
----------------------------------------------------------------------------------------------------- | ---------------------------------------------: | --------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------: | ---------------------------------------------------------------------------:
|
||||
BlazePose.Lite | 49.1 | 91.7 | 49 | 40
|
||||
BlazePose.Full | 64.5 | 95.8 | 40 | 37
|
||||
BlazePose.Heavy | 70.9 | 97.0 | 19 | 26
|
||||
[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 57.6 | 93.1 | N/A | N/A
|
||||
[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 37.0 | 85.3 | N/A | N/A
|
||||
|
||||
## Models
|
||||
|
||||
### Person/pose Detection Model (BlazePose Detector)
|
||||
|
@ -97,11 +113,8 @@ hip midpoints.
|
|||
|
||||
### Pose Landmark Model (BlazePose GHUM 3D)
|
||||
|
||||
The landmark model in MediaPipe Pose comes in two versions: a full-body model
|
||||
that predicts the location of 33 pose landmarks (see figure below), and an
|
||||
upper-body version that only predicts the first 25. The latter may be more
|
||||
accurate than the former in scenarios where the lower-body parts are mostly out
|
||||
of view.
|
||||
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
|
||||
(see figure below).
|
||||
|
||||
Please find more detail in the
|
||||
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
|
||||
|
@ -129,12 +142,11 @@ until it loses track, on reducing computation and latency. If set to `true`,
|
|||
person detection runs every input image, ideal for processing a batch of static,
|
||||
possibly unrelated, images. Default to `false`.
|
||||
|
||||
#### upper_body_only
|
||||
#### model_complexity
|
||||
|
||||
If set to `true`, the solution outputs only the 25 upper-body pose landmarks.
|
||||
Otherwise, it outputs the full set of 33 pose landmarks. Note that
|
||||
upper-body-only prediction may be more accurate for use cases where the
|
||||
lower-body parts are mostly out of view. Default to `false`.
|
||||
Complexity of the pose landmark model: `0`, `1` or `2`. Landmark accuracy as
|
||||
well as inference latency generally go up with the model complexity. Default to
|
||||
`1`.
|
||||
|
||||
#### smooth_landmarks
|
||||
|
||||
|
@ -170,9 +182,6 @@ A list of pose landmarks. Each lanmark consists of the following:
|
|||
being the origin, and the smaller the value the closer the landmark is to
|
||||
the camera. The magnitude of `z` uses roughly the same scale as `x`.
|
||||
|
||||
Note: `z` is predicted only in full-body mode, and should be discarded when
|
||||
[upper_body_only](#upper_body_only) is `true`.
|
||||
|
||||
* `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the
|
||||
landmark being visible (present and not occluded) in the image.
|
||||
|
||||
|
@ -185,7 +194,7 @@ install MediaPipe Python package, then learn more in the companion
|
|||
Supported configuration options:
|
||||
|
||||
* [static_image_mode](#static_image_mode)
|
||||
* [upper_body_only](#upper_body_only)
|
||||
* [model_complexity](#model_complexity)
|
||||
* [smooth_landmarks](#smooth_landmarks)
|
||||
* [min_detection_confidence](#min_detection_confidence)
|
||||
* [min_tracking_confidence](#min_tracking_confidence)
|
||||
|
@ -198,7 +207,9 @@ mp_pose = mp.solutions.pose
|
|||
|
||||
# For static images:
|
||||
with mp_pose.Pose(
|
||||
static_image_mode=True, min_detection_confidence=0.5) as pose:
|
||||
static_image_mode=True,
|
||||
model_complexity=2,
|
||||
min_detection_confidence=0.5) as pose:
|
||||
for idx, file in enumerate(file_list):
|
||||
image = cv2.imread(file)
|
||||
image_height, image_width, _ = image.shape
|
||||
|
@ -214,8 +225,6 @@ with mp_pose.Pose(
|
|||
)
|
||||
# Draw pose landmarks on the image.
|
||||
annotated_image = image.copy()
|
||||
# Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when
|
||||
# upper_body_only is set to True.
|
||||
mp_drawing.draw_landmarks(
|
||||
annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
|
||||
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
|
||||
|
@ -259,7 +268,7 @@ and the following usage example.
|
|||
|
||||
Supported configuration options:
|
||||
|
||||
* [upperBodyOnly](#upper_body_only)
|
||||
* [modelComplexity](#model_complexity)
|
||||
* [smoothLandmarks](#smooth_landmarks)
|
||||
* [minDetectionConfidence](#min_detection_confidence)
|
||||
* [minTrackingConfidence](#min_tracking_confidence)
|
||||
|
@ -306,7 +315,7 @@ const pose = new Pose({locateFile: (file) => {
|
|||
return `https://cdn.jsdelivr.net/npm/@mediapipe/pose/${file}`;
|
||||
}});
|
||||
pose.setOptions({
|
||||
upperBodyOnly: false,
|
||||
modelComplexity: 1,
|
||||
smoothLandmarks: true,
|
||||
minDetectionConfidence: 0.5,
|
||||
minTrackingConfidence: 0.5
|
||||
|
@ -347,16 +356,6 @@ to visualize its associated subgraphs, please see
|
|||
* iOS target:
|
||||
[`mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp`](http:/mediapipe/examples/ios/posetrackinggpu/BUILD)
|
||||
|
||||
#### Upper-body Only
|
||||
|
||||
* Graph:
|
||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
|
||||
* Android target:
|
||||
[(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1uKc6T7KSuA0Mlq2URi5YookHu0U3yoh_/view?usp=sharing)
|
||||
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu:upperbodyposetrackinggpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD)
|
||||
* iOS target:
|
||||
[`mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp`](http:/mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD)
|
||||
|
||||
### Desktop
|
||||
|
||||
Please first see general instructions for [desktop](../getting_started/cpp.md)
|
||||
|
@ -375,19 +374,6 @@ on how to build MediaPipe examples.
|
|||
* Target:
|
||||
[`mediapipe/examples/desktop/pose_tracking:pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/pose_tracking/BUILD)
|
||||
|
||||
#### Upper-body Only
|
||||
|
||||
* Running on CPU
|
||||
* Graph:
|
||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_cpu.pbtxt)
|
||||
* Target:
|
||||
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
|
||||
* Running on GPU
|
||||
* Graph:
|
||||
[`mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/pose_tracking/upper_body_pose_tracking_gpu.pbtxt)
|
||||
* Target:
|
||||
[`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD)
|
||||
|
||||
## Resources
|
||||
|
||||
* Google AI Blog:
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/posetrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD",
|
||||
"mediapipe/framework/BUILD",
|
||||
"mediapipe/gpu/BUILD",
|
||||
"mediapipe/objc/BUILD",
|
||||
|
@ -36,7 +35,6 @@
|
|||
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp",
|
||||
"//mediapipe/objc:mediapipe_framework_ios"
|
||||
],
|
||||
"optionSet" : {
|
||||
|
@ -105,7 +103,6 @@
|
|||
"mediapipe/examples/ios/objectdetectioncpu",
|
||||
"mediapipe/examples/ios/objectdetectiongpu",
|
||||
"mediapipe/examples/ios/posetrackinggpu",
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu",
|
||||
"mediapipe/framework",
|
||||
"mediapipe/framework/deps",
|
||||
"mediapipe/framework/formats",
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
"mediapipe/examples/ios/objectdetectiongpu",
|
||||
"mediapipe/examples/ios/objectdetectiontrackinggpu",
|
||||
"mediapipe/examples/ios/posetrackinggpu",
|
||||
"mediapipe/examples/ios/upperbodyposetrackinggpu",
|
||||
"mediapipe/objc"
|
||||
],
|
||||
"projectName" : "Mediapipe",
|
||||
|
|
|
@ -451,8 +451,8 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "nonzero_calculator",
|
||||
srcs = ["nonzero_calculator.cc"],
|
||||
name = "non_zero_calculator",
|
||||
srcs = ["non_zero_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
|
@ -464,6 +464,21 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "non_zero_calculator_test",
|
||||
size = "small",
|
||||
srcs = ["non_zero_calculator_test.cc"],
|
||||
deps = [
|
||||
":non_zero_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mux_calculator_test",
|
||||
srcs = ["mux_calculator_test.cc"],
|
||||
|
@ -665,6 +680,18 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "default_side_packet_calculator",
|
||||
srcs = ["default_side_packet_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "side_packet_to_stream_calculator",
|
||||
srcs = ["side_packet_to_stream_calculator.cc"],
|
||||
|
|
103
mediapipe/calculators/core/default_side_packet_calculator.cc
Normal file
|
@ -0,0 +1,103 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE";
|
||||
constexpr char kDefaultValueTag[] = "DEFAULT_VALUE";
|
||||
constexpr char kValueTag[] = "VALUE";
|
||||
|
||||
} // namespace
|
||||
|
||||
// Outputs side packet default value if optional value is not provided.
|
||||
//
|
||||
// This calculator utilizes the fact that MediaPipe automatically removes
|
||||
// optional side packets of the calculator configuration (i.e. OPTIONAL_VALUE).
|
||||
// And if it happens - returns default value, otherwise - returns optional
|
||||
// value.
|
||||
//
|
||||
// Input:
|
||||
// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE)
|
||||
// Optional side packet value that is outputted by the calculator as is if
|
||||
// provided.
|
||||
//
|
||||
// DEFAULT_VALUE - AnyType
|
||||
// Default side pack value that is outputted by the calculator if
|
||||
// OPTIONAL_VALUE is not provided.
|
||||
//
|
||||
// Output:
|
||||
// VALUE - AnyType (but same type as DEFAULT_VALUE)
|
||||
// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise).
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "DefaultSidePacketCalculator"
|
||||
// input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
|
||||
// input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
|
||||
// output_side_packet: "VALUE:segmentation_mask_enabled"
|
||||
// }
|
||||
class DefaultSidePacketCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(DefaultSidePacketCalculator);
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
|
||||
<< "Default value must be provided";
|
||||
cc->InputSidePackets().Tag(kDefaultValueTag).SetAny();
|
||||
|
||||
// Optional input side packet can be unspecified. In this case MediaPipe will
|
||||
// remove it from the calculator config.
|
||||
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag(kOptionalValueTag)
|
||||
.SetSameAs(&cc->InputSidePackets().Tag(kDefaultValueTag));
|
||||
}
|
||||
|
||||
RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
|
||||
cc->OutputSidePackets().Tag(kValueTag).SetSameAs(
|
||||
&cc->InputSidePackets().Tag(kDefaultValueTag));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
|
||||
// If optional value is provided it is returned as the calculator output.
|
||||
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||
auto& packet = cc->InputSidePackets().Tag(kOptionalValueTag);
|
||||
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// If no optional value
|
||||
auto& packet = cc->InputSidePackets().Tag(kDefaultValueTag);
|
||||
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -23,14 +23,26 @@ namespace api2 {
|
|||
class NonZeroCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<int>::SideFallback kIn{"INPUT"};
|
||||
static constexpr Output<int> kOut{"OUTPUT"};
|
||||
static constexpr Output<int>::Optional kOut{"OUTPUT"};
|
||||
static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);
|
||||
|
||||
absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
|
||||
<< "At least one output stream is expected.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!kIn(cc).IsEmpty()) {
|
||||
auto output = std::make_unique<int>((*kIn(cc) != 0) ? 1 : 0);
|
||||
kOut(cc).Send(std::move(output));
|
||||
bool isNonZero = *kIn(cc) != 0;
|
||||
if (kOut(cc).IsConnected()) {
|
||||
kOut(cc).Send(std::make_unique<int>(isNonZero ? 1 : 0));
|
||||
}
|
||||
if (kBooleanOut(cc).IsConnected()) {
|
||||
kBooleanOut(cc).Send(std::make_unique<bool>(isNonZero));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
93
mediapipe/calculators/core/non_zero_calculator_test.cc
Normal file
|
@ -0,0 +1,93 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class NonZeroCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
NonZeroCalculatorTest()
|
||||
: runner_(
|
||||
R"pb(
|
||||
calculator: "NonZeroCalculator"
|
||||
input_stream: "INPUT:input"
|
||||
output_stream: "OUTPUT:output"
|
||||
output_stream: "OUTPUT_BOOL:output_bool"
|
||||
)pb") {}
|
||||
|
||||
void SetInput(const std::vector<int>& inputs) {
|
||||
int timestamp = 0;
|
||||
for (const auto input : inputs) {
|
||||
runner_.MutableInputs()
|
||||
->Get("INPUT", 0)
|
||||
.packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> GetOutput() {
|
||||
std::vector<int> result;
|
||||
for (const auto output : runner_.Outputs().Get("OUTPUT", 0).packets) {
|
||||
result.push_back(output.Get<int>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<bool> GetOutputBool() {
|
||||
std::vector<bool> result;
|
||||
for (const auto output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
|
||||
result.push_back(output.Get<bool>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CalculatorRunner runner_;
|
||||
};
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
|
||||
SetInput({0});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
|
||||
EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
|
||||
}
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
|
||||
SetInput({1, 2, 3, -4, 5});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
|
||||
EXPECT_THAT(GetOutputBool(),
|
||||
::testing::ElementsAre(true, true, true, true, true));
|
||||
}
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
|
||||
SetInput({1, 0, 3, 0, 5});
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
|
||||
EXPECT_THAT(GetOutputBool(),
|
||||
::testing::ElementsAre(true, false, true, false, true));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -285,7 +285,7 @@ absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
|
|||
|
||||
// Run cropping shader on GPU.
|
||||
{
|
||||
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(dst_tex);
|
||||
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(src_tex.target(), src_tex.name());
|
||||
|
|
|
@ -546,7 +546,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
|
|||
auto dst = gpu_helper_.CreateDestinationTexture(output_width, output_height,
|
||||
input.format());
|
||||
|
||||
gpu_helper_.BindFramebuffer(dst); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(dst);
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(src1.target(), src1.name());
|
||||
|
||||
|
|
|
@ -209,6 +209,9 @@ absl::Status RecolorCalculator::Close(CalculatorContext* cc) {
|
|||
|
||||
absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) {
|
||||
cc->Outputs()
|
||||
.Tag(kImageFrameTag)
|
||||
.AddPacket(cc->Inputs().Tag(kImageFrameTag).Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// Get inputs and setup output.
|
||||
|
@ -270,6 +273,9 @@ absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
|
|||
|
||||
absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) {
|
||||
cc->Outputs()
|
||||
.Tag(kGpuBufferTag)
|
||||
.AddPacket(cc->Inputs().Tag(kGpuBufferTag).Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -287,7 +293,7 @@ absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
|||
|
||||
// Run recolor shader on GPU.
|
||||
{
|
||||
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(dst_tex);
|
||||
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(img_tex.target(), img_tex.name());
|
||||
|
|
|
@ -323,7 +323,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
|
|||
const auto& alpha_mask =
|
||||
cc->Inputs().Tag(kInputAlphaTagGpu).Get<mediapipe::GpuBuffer>();
|
||||
auto alpha_texture = gpu_helper_.CreateSourceTexture(alpha_mask);
|
||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(output_texture);
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||
glActiveTexture(GL_TEXTURE2);
|
||||
|
@ -335,7 +335,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
|
|||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
alpha_texture.Release();
|
||||
} else {
|
||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(output_texture);
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||
GlRender(cc); // use value from options
|
||||
|
|
|
@ -490,6 +490,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [":image_to_tensor_calculator_gpu_deps"],
|
||||
|
@ -526,6 +527,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/gpu:gpu_origin_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
|
@ -236,7 +237,7 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
|
||||
private:
|
||||
bool DoesInputStartAtBottom() {
|
||||
bool DoesGpuInputStartAtBottom() {
|
||||
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
|
@ -290,11 +291,11 @@ class ImageToTensorCalculator : public Node {
|
|||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlBufferTensorConverter(
|
||||
cc, DoesInputStartAtBottom(), GetBorderMode()));
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
#else
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlTextureTensorConverter(
|
||||
cc, DoesInputStartAtBottom(), GetBorderMode()));
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
|
|
@ -17,20 +17,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message GpuOrigin {
|
||||
enum Mode {
|
||||
DEFAULT = 0;
|
||||
|
||||
// OpenGL: bottom-left origin
|
||||
// Metal : top-left origin
|
||||
CONVENTIONAL = 1;
|
||||
|
||||
// OpenGL: top-left origin
|
||||
// Metal : top-left origin
|
||||
TOP_LEFT = 2;
|
||||
}
|
||||
}
|
||||
import "mediapipe/gpu/gpu_origin.proto";
|
||||
|
||||
message ImageToTensorCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
|
|
|
@ -317,7 +317,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
|||
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
options.compile_options.precision_loss_allowed =
|
||||
allow_precision_loss_ ? 1 : 0;
|
||||
options.compile_options.preferred_gl_object_type =
|
||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||
options.compile_options.dynamic_batch_enabled = 0;
|
||||
|
|
|
@ -97,6 +97,7 @@ class InferenceCalculatorMetalImpl
|
|||
Packet<TfLiteModelPtr> model_packet_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
bool allow_precision_loss_ = false;
|
||||
|
||||
#if MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
|
@ -122,6 +123,9 @@ absl::Status InferenceCalculatorMetalImpl::UpdateContract(
|
|||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
|
||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
|
||||
|
||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
|
@ -222,7 +226,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
|
||||
// Configure and create the delegate.
|
||||
TFLGpuDelegateOptions options;
|
||||
options.allow_precision_loss = true;
|
||||
options.allow_precision_loss = allow_precision_loss_;
|
||||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
|
||||
delegate_ =
|
||||
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
|
||||
|
@ -239,7 +243,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
tensor->dims->data + tensor->dims->size};
|
||||
dims.back() = RoundUp(dims.back(), 4);
|
||||
gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>(
|
||||
Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
|
||||
allow_precision_loss_ ? Tensor::ElementType::kFloat16
|
||||
: Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{dims}));
|
||||
auto buffer_view =
|
||||
gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
|
||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||
|
@ -261,7 +267,9 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
output_shapes_[i] = {dims};
|
||||
dims.back() = RoundUp(dims.back(), 4);
|
||||
gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>(
|
||||
Tensor::ElementType::kFloat16, Tensor::Shape{dims}));
|
||||
allow_precision_loss_ ? Tensor::ElementType::kFloat16
|
||||
: Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{dims}));
|
||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||
delegate_.get(), output_indices[i],
|
||||
gpu_buffers_out_[i]
|
||||
|
@ -271,17 +279,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
// Create converter for GPU input.
|
||||
converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:true
|
||||
convertToPBHWC4:true];
|
||||
converter_to_BPHWC4_ =
|
||||
[[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:allow_precision_loss_
|
||||
convertToPBHWC4:true];
|
||||
if (converter_to_BPHWC4_ == nil) {
|
||||
return mediapipe::InternalError(
|
||||
"Error initializating input buffer converter");
|
||||
}
|
||||
// Create converter for GPU output.
|
||||
converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:true
|
||||
convertToPBHWC4:false];
|
||||
converter_from_BPHWC4_ =
|
||||
[[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:allow_precision_loss_
|
||||
convertToPBHWC4:false];
|
||||
if (converter_from_BPHWC4_ == nil) {
|
||||
return absl::InternalError("Error initializating output buffer converter");
|
||||
}
|
||||
|
|
|
@ -89,7 +89,8 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
|||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options_.label_map_path()));
|
||||
std::string label_map_string;
|
||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::GetResourceContents(string_path, &label_map_string));
|
||||
|
||||
std::istringstream stream(label_map_string);
|
||||
std::string line;
|
||||
|
@ -98,6 +99,14 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
|||
label_map_[i++] = line;
|
||||
}
|
||||
label_map_loaded_ = true;
|
||||
} else if (options_.has_label_map()) {
|
||||
for (int i = 0; i < options_.label_map().entries_size(); ++i) {
|
||||
const auto& entry = options_.label_map().entries(i);
|
||||
RET_CHECK(!label_map_.contains(entry.id()))
|
||||
<< "Duplicate id found: " << entry.id();
|
||||
label_map_[entry.id()] = entry.label();
|
||||
}
|
||||
label_map_loaded_ = true;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -25,6 +25,14 @@ message TensorsToClassificationCalculatorOptions {
|
|||
optional TensorsToClassificationCalculatorOptions ext = 335742638;
|
||||
}
|
||||
|
||||
message LabelMap {
|
||||
message Entry {
|
||||
optional int32 id = 1;
|
||||
optional string label = 2;
|
||||
}
|
||||
repeated Entry entries = 1;
|
||||
}
|
||||
|
||||
// Score threshold for perserving the class.
|
||||
optional float min_score_threshold = 1;
|
||||
// Number of highest scoring labels to output. If top_k is not positive then
|
||||
|
@ -32,6 +40,10 @@ message TensorsToClassificationCalculatorOptions {
|
|||
optional int32 top_k = 2;
|
||||
// Path to a label map file for getting the actual name of class ids.
|
||||
optional string label_map_path = 3;
|
||||
// Label map. (Can be used instead of label_map_path.)
|
||||
// NOTE: "label_map_path", if specified, takes precedence over "label_map".
|
||||
optional LabelMap label_map = 5;
|
||||
|
||||
// Whether the input is a single float for binary classification.
|
||||
// When true, only a single float is expected in the input tensor and the
|
||||
// label map, if provided, is expected to have exactly two labels.
|
||||
|
|
|
@ -115,6 +115,41 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMapPath) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithLabelMap) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToClassificationCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "CLASSIFICATIONS:classifications"
|
||||
options {
|
||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
label_map {
|
||||
entries { id: 0, label: "ClassA" }
|
||||
entries { id: 1, label: "ClassB" }
|
||||
entries { id: 2, label: "ClassC" }
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
BuildGraph(&runner, {0, 0.5, 1});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||
|
||||
EXPECT_EQ(1, output_packets_.size());
|
||||
|
||||
const auto& classification_list =
|
||||
output_packets_[0].Get<ClassificationList>();
|
||||
EXPECT_EQ(3, classification_list.classification_size());
|
||||
|
||||
// Verify that the label field is set.
|
||||
for (int i = 0; i < classification_list.classification_size(); ++i) {
|
||||
EXPECT_EQ(i, classification_list.classification(i).index());
|
||||
EXPECT_EQ(i * 0.5, classification_list.classification(i).score());
|
||||
ASSERT_TRUE(classification_list.classification(i).has_label());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
CorrectOutputWithLabelMinScoreThreshold) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
|
|
|
@ -34,15 +34,28 @@ constexpr char kTensor[] = "TENSOR";
|
|||
} // namespace
|
||||
|
||||
// Input:
|
||||
// Tensor of type DT_FLOAT, with values between 0-255 (SRGB or GRAY8). The
|
||||
// shape can be HxWx{3,1} or simply HxW.
|
||||
// Tensor of type DT_FLOAT or DT_UINT8, with values between 0-255
|
||||
// (SRGB or GRAY8). The shape can be HxWx{3,1} or simply HxW.
|
||||
//
|
||||
// Optionally supports a scale factor that can scale 0-1 value ranges to 0-255.
|
||||
// For DT_FLOAT tensors, optionally supports a scale factor that can scale 0-1
|
||||
// value ranges to 0-255.
|
||||
//
|
||||
// Output:
|
||||
// ImageFrame containing the values of the tensor cast as uint8 (SRGB or GRAY8)
|
||||
//
|
||||
// Possible extensions: support other input ranges, maybe 4D tensors.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "TensorToImageFrameCalculator"
|
||||
// input_stream: "TENSOR:3d_float_tensor"
|
||||
// output_stream: "IMAGE:image_frame"
|
||||
// options {
|
||||
// [mediapipe.TensorToImageFrameCalculatorOptions.ext] {
|
||||
// scale_factor: 1.0 # set to 255.0 for [0,1] -> [0,255] scaling
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorToImageFrameCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
@ -57,8 +70,8 @@ class TensorToImageFrameCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
||||
|
||||
absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
||||
<< "Only one input stream is supported.";
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
|
||||
<< "Only one output stream is supported.";
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
||||
<< "One input stream must be provided.";
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensor))
|
||||
|
@ -91,29 +104,44 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
|||
RET_CHECK_EQ(depth, 3) << "Output tensor depth must be 3 or 1.";
|
||||
}
|
||||
}
|
||||
const int32 total_size =
|
||||
input_tensor.dim_size(0) * input_tensor.dim_size(1) * depth;
|
||||
std::unique_ptr<uint8[]> buffer(new uint8[total_size]);
|
||||
auto data = input_tensor.flat<float>().data();
|
||||
for (int i = 0; i < total_size; ++i) {
|
||||
float d = scale_factor_ * data[i];
|
||||
if (d < 0) d = 0;
|
||||
if (d > 255) d = 255;
|
||||
buffer[i] = d;
|
||||
int32 height = input_tensor.dim_size(0);
|
||||
int32 width = input_tensor.dim_size(1);
|
||||
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
||||
const int32 total_size = height * width * depth;
|
||||
|
||||
::std::unique_ptr<const ImageFrame> output;
|
||||
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
||||
// Allocate buffer with alignments.
|
||||
std::unique_ptr<uint8_t[]> buffer(
|
||||
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
||||
auto data = input_tensor.flat<float>().data();
|
||||
for (int i = 0; i < total_size; ++i) {
|
||||
float d = scale_factor_ * data[i];
|
||||
if (d < 0) d = 0;
|
||||
if (d > 255) d = 255;
|
||||
buffer[i] = d;
|
||||
}
|
||||
output = ::absl::make_unique<ImageFrame>(format, width, height,
|
||||
width * depth, buffer.release());
|
||||
} else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
|
||||
if (scale_factor_ != 1.0) {
|
||||
return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");
|
||||
}
|
||||
// tf::Tensor has internally ref-counted buffer. The following code make the
|
||||
// ImageFrame own the copied Tensor through the deleter, which increases
|
||||
// the refcount of the buffer and allow us to use the shared buffer as the
|
||||
// image. This allows us to create an ImageFrame object without copying
|
||||
// buffer. const ImageFrame prevents the buffer from being modified later.
|
||||
auto copy = new tf::Tensor(input_tensor);
|
||||
output = ::absl::make_unique<const ImageFrame>(
|
||||
format, width, height, width * depth, copy->flat<uint8_t>().data(),
|
||||
[copy](uint8*) { delete copy; });
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected float or uint8 tensor, received ",
|
||||
DataTypeString(input_tensor.dtype())));
|
||||
}
|
||||
|
||||
::std::unique_ptr<ImageFrame> output;
|
||||
if (depth == 3) {
|
||||
output = ::absl::make_unique<ImageFrame>(
|
||||
ImageFormat::SRGB, input_tensor.dim_size(1), input_tensor.dim_size(0),
|
||||
input_tensor.dim_size(1) * 3, buffer.release());
|
||||
} else if (depth == 1) {
|
||||
output = ::absl::make_unique<ImageFrame>(
|
||||
ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0),
|
||||
input_tensor.dim_size(1), buffer.release());
|
||||
} else {
|
||||
return absl::InvalidArgumentError("Unrecognized image depth.");
|
||||
}
|
||||
cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -29,6 +29,7 @@ constexpr char kImage[] = "IMAGE";
|
|||
|
||||
} // namespace
|
||||
|
||||
template <class TypeParam>
|
||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpRunner() {
|
||||
|
@ -42,14 +43,20 @@ class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
|||
std::unique_ptr<CalculatorRunner> runner_;
|
||||
};
|
||||
|
||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
||||
SetUpRunner();
|
||||
using TensorToImageFrameCalculatorTestTypes = ::testing::Types<float, uint8_t>;
|
||||
TYPED_TEST_CASE(TensorToImageFrameCalculatorTest,
|
||||
TensorToImageFrameCalculatorTestTypes);
|
||||
|
||||
TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
||||
// TYPED_TEST requires explicit "this->"
|
||||
this->SetUpRunner();
|
||||
auto& runner = this->runner_;
|
||||
constexpr int kWidth = 16;
|
||||
constexpr int kHeight = 8;
|
||||
const tf::TensorShape tensor_shape(
|
||||
std::vector<tf::int64>{kHeight, kWidth, 3});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
||||
auto tensor_vec = tensor->flat<float>().data();
|
||||
const tf::TensorShape tensor_shape{kHeight, kWidth, 3};
|
||||
auto tensor = absl::make_unique<tf::Tensor>(
|
||||
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||
|
||||
// Writing sequence of integers as floats which we want back (as they were
|
||||
// written).
|
||||
|
@ -58,15 +65,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
|||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
EXPECT_TRUE(runner->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kImage).packets;
|
||||
runner->Outputs().Tag(kImage).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||
EXPECT_EQ(ImageFormat::SRGB, output_image.Format());
|
||||
EXPECT_EQ(kWidth, output_image.Width());
|
||||
EXPECT_EQ(kHeight, output_image.Height());
|
||||
|
||||
|
@ -76,14 +84,15 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
||||
SetUpRunner();
|
||||
TYPED_TEST(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
||||
this->SetUpRunner();
|
||||
auto& runner = this->runner_;
|
||||
constexpr int kWidth = 16;
|
||||
constexpr int kHeight = 8;
|
||||
const tf::TensorShape tensor_shape(
|
||||
std::vector<tf::int64>{kHeight, kWidth, 1});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
||||
auto tensor_vec = tensor->flat<float>().data();
|
||||
const tf::TensorShape tensor_shape{kHeight, kWidth, 1};
|
||||
auto tensor = absl::make_unique<tf::Tensor>(
|
||||
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||
|
||||
// Writing sequence of integers as floats which we want back (as they were
|
||||
// written).
|
||||
|
@ -92,15 +101,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
|||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
EXPECT_TRUE(runner->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kImage).packets;
|
||||
runner->Outputs().Tag(kImage).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||
EXPECT_EQ(kWidth, output_image.Width());
|
||||
EXPECT_EQ(kHeight, output_image.Height());
|
||||
|
||||
|
@ -110,13 +120,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrameGray) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
|
||||
SetUpRunner();
|
||||
TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||
Converts3DTensorToImageFrame2DGray) {
|
||||
this->SetUpRunner();
|
||||
auto& runner = this->runner_;
|
||||
constexpr int kWidth = 16;
|
||||
constexpr int kHeight = 8;
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{kHeight, kWidth});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_FLOAT, tensor_shape);
|
||||
auto tensor_vec = tensor->flat<float>().data();
|
||||
const tf::TensorShape tensor_shape{kHeight, kWidth};
|
||||
auto tensor = absl::make_unique<tf::Tensor>(
|
||||
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||
|
||||
// Writing sequence of integers as floats which we want back (as they were
|
||||
// written).
|
||||
|
@ -125,15 +138,16 @@ TEST_F(TensorToImageFrameCalculatorTest, Converts3DTensorToImageFrame2DGray) {
|
|||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
EXPECT_TRUE(runner->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kImage).packets;
|
||||
runner->Outputs().Tag(kImage).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||
EXPECT_EQ(kWidth, output_image.Width());
|
||||
EXPECT_EQ(kHeight, output_image.Height());
|
||||
|
||||
|
|
|
@ -91,8 +91,6 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
|
|||
// the input data when it arrives in Process(). In particular, if the header
|
||||
// states that we produce a 1xD column vector, the input tensor must also be 1xD
|
||||
//
|
||||
// This designed was discussed in http://g/speakeranalysis/4uyx7cNRwJY and
|
||||
// http://g/daredevil-project/VB26tcseUy8.
|
||||
// Example Config
|
||||
// node: {
|
||||
// calculator: "TensorToMatrixCalculator"
|
||||
|
@ -158,22 +156,17 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
|
|||
if (header_status.ok()) {
|
||||
if (cc->Options<TensorToMatrixCalculatorOptions>()
|
||||
.has_time_series_header_overrides()) {
|
||||
// From design discussions with Daredevil, we only want to support single
|
||||
// sample per packet for now, so we hardcode the sample_rate based on the
|
||||
// packet_rate of the REFERENCE and fail noisily if we cannot. An
|
||||
// alternative would be to calculate the sample_rate from the reference
|
||||
// sample_rate and the change in num_samples between the reference and
|
||||
// override headers:
|
||||
// sample_rate_output = sample_rate_reference /
|
||||
// (num_samples_override / num_samples_reference)
|
||||
// This only supports a single sample per packet for now, so we hardcode
|
||||
// the sample_rate based on the packet_rate of the REFERENCE and fail
|
||||
// if we cannot.
|
||||
const TimeSeriesHeader& override_header =
|
||||
cc->Options<TensorToMatrixCalculatorOptions>()
|
||||
.time_series_header_overrides();
|
||||
input_header->MergeFrom(override_header);
|
||||
CHECK(input_header->has_packet_rate())
|
||||
RET_CHECK(input_header->has_packet_rate())
|
||||
<< "The TimeSeriesHeader.packet_rate must be set.";
|
||||
if (!override_header.has_sample_rate()) {
|
||||
CHECK_EQ(input_header->num_samples(), 1)
|
||||
RET_CHECK_EQ(input_header->num_samples(), 1)
|
||||
<< "Currently the time series can only output single samples.";
|
||||
input_header->set_sample_rate(input_header->packet_rate());
|
||||
}
|
||||
|
@ -186,20 +179,16 @@ absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
|
||||
// Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK
|
||||
// failures. These are absolute conditions of the graph for the graph to be
|
||||
// valid, and if it is violated by any input anywhere, the graph will be
|
||||
// invalid for all inputs. A hard CHECK will enable faster debugging by
|
||||
// immediately exiting and more prominently displaying error messages.
|
||||
// Do not replace with RET_CHECKs.
|
||||
|
||||
// Verify that each reference stream packet corresponds to a tensor packet
|
||||
// otherwise the header information is invalid. If we don't have a reference
|
||||
// stream, Process() is only called when we have an input tensor and this is
|
||||
// always True.
|
||||
CHECK(cc->Inputs().HasTag(kTensor))
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensor))
|
||||
<< "Tensor stream not available at same timestamp as the reference "
|
||||
"stream.";
|
||||
RET_CHECK(!cc->Inputs().Tag(kTensor).IsEmpty()) << "Tensor stream is empty.";
|
||||
RET_CHECK_OK(cc->Inputs().Tag(kTensor).Value().ValidateAsType<tf::Tensor>())
|
||||
<< "Tensor stream packet does not contain a Tensor.";
|
||||
|
||||
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
|
||||
CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
|
||||
|
@ -207,13 +196,12 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
|
|||
const int32 length = input_tensor.dim_size(input_tensor.dims() - 1);
|
||||
const int32 width = (1 == input_tensor.dims()) ? 1 : input_tensor.dim_size(0);
|
||||
if (header_.has_num_channels()) {
|
||||
CHECK_EQ(length, header_.num_channels())
|
||||
RET_CHECK_EQ(length, header_.num_channels())
|
||||
<< "The number of channels at runtime does not match the header.";
|
||||
}
|
||||
if (header_.has_num_samples()) {
|
||||
CHECK_EQ(width, header_.num_samples())
|
||||
RET_CHECK_EQ(width, header_.num_samples())
|
||||
<< "The number of samples at runtime does not match the header.";
|
||||
;
|
||||
}
|
||||
auto output = absl::make_unique<Matrix>(width, length);
|
||||
*output =
|
||||
|
|
|
@ -98,388 +98,543 @@ class InferenceState {
|
|||
|
||||
// This calculator performs inference on a trained TensorFlow model.
|
||||
//
|
||||
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
||||
// For this calculator it must include a tag_to_tensor_map.
|
||||
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("RECURRENT_INIT_TENSORS")
|
||||
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// TensorFlow Sessions can be created from checkpoint paths, frozen models, or
|
||||
// the SavedModel system. See the TensorFlowSessionFrom* packet generators for
|
||||
// details. Each of these methods defines a mapping between MediaPipe streams
|
||||
// and TensorFlow tensors. All of this information is passed in as an
|
||||
// input_side_packet.
|
||||
//
|
||||
// The input and output streams are TensorFlow tensors labeled by tags. The tags
|
||||
// for the streams are matched to feeds and fetchs in a TensorFlow session using
|
||||
// a named_signature.generic_signature in the ModelManifest. The
|
||||
// generic_signature is used as key-value pairs between the MediaPipe tag and
|
||||
// the TensorFlow tensor. The signature_name in the options proto determines
|
||||
// which named_signature is used. The keys in the generic_signature must be
|
||||
// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
|
||||
// the tensors corresponding to tags in the signature for input_streams are fed
|
||||
// to the model and for output_streams the tensors are fetched from the model.
|
||||
//
|
||||
// Other calculators are used to convert data to and from tensors, this op only
|
||||
// handles the TensorFlow session and batching. Batching occurs by concatenating
|
||||
// input tensors along the 0th dimension across timestamps. If the 0th dimension
|
||||
// is not a batch dimension, this calculator will add a 0th dimension by
|
||||
// default. Setting add_batch_dim_to_tensors to false disables the dimension
|
||||
// addition. Once batch_size inputs have been provided, the batch will be run
|
||||
// and the output tensors sent out on the output streams with timestamps
|
||||
// corresponding to the input stream packets. Setting the batch_size to 1
|
||||
// completely disables batching, but is indepdent of add_batch_dim_to_tensors.
|
||||
//
|
||||
// The TensorFlowInferenceCalculator also support feeding states recurrently for
|
||||
// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
|
||||
// recurrent tensors. Initializing the recurrent state can be handled by the
|
||||
// GraphTensorsPacketGenerator.
|
||||
//
|
||||
// The calculator updates two Counters to report timing information:
|
||||
// --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
|
||||
// --<name>-TotalProcessedTimestamps = # of instances processed
|
||||
// (approximately batches processed * batch_size),
|
||||
// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
|
||||
// exists, or with TensorFlowInferenceCalculator if the name is not set. The
|
||||
// name must be set for timing information to be instance-specific in graphs
|
||||
// with multiple TensorFlowInferenceCalculators.
|
||||
//
|
||||
// Example config:
|
||||
// packet_generator {
|
||||
// packet_generator: "TensorFlowSessionFromSavedModelGenerator"
|
||||
// output_side_packet: "tensorflow_session"
|
||||
// options {
|
||||
// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
|
||||
// saved_model_path: "/path/to/saved/model"
|
||||
// signature_name: "mediapipe"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// node {
|
||||
// calculator: "TensorFlowInferenceCalculator"
|
||||
// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
|
||||
// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
|
||||
// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
|
||||
// input_side_packet: "SESSION:tensorflow_session"
|
||||
// }
|
||||
//
|
||||
// Where the input and output streams are treated as Packet<tf::Tensor> and
|
||||
// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
|
||||
// "LABELS" and their respective tensors exported to /path/to/bundle. For an
|
||||
// example of how this model was exported, see
|
||||
// tensorflow_inference_test_graph_generator.py
|
||||
//
|
||||
// It is possible to use a GraphDef proto that was not exported by exporter (i.e
|
||||
// without MetaGraph with bindings). Such GraphDef could contain all of its
|
||||
// parameters in-lined (for example, it can be the output of freeze_graph.py).
|
||||
// To instantiate a TensorFlow model from a GraphDef file, replace the
|
||||
// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator:
|
||||
//
|
||||
// packet_generator {
|
||||
// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
|
||||
// output_side_packet: "SESSION:tensorflow_session"
|
||||
// options {
|
||||
// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
|
||||
// graph_proto_path: "[PATH]"
|
||||
// tag_to_tensor_names {
|
||||
// key: "JPG_STRING"
|
||||
// value: "input:0"
|
||||
// }
|
||||
// tag_to_tensor_names {
|
||||
// key: "SOFTMAX"
|
||||
// value: "softmax:0"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// It is also possible to use a GraphDef proto and checkpoint file that have not
|
||||
// been frozen. This can be used to load graphs directly as they have been
|
||||
// written from training. However, it is more brittle and you are encouraged to
|
||||
// use a one of the more perminent formats described above. To instantiate a
|
||||
// TensorFlow model from a GraphDef file and checkpoint, replace the
|
||||
// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator:
|
||||
//
|
||||
// packet_generator {
|
||||
// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator"
|
||||
// output_side_packet: "SESSION:tensorflow_session"
|
||||
// options {
|
||||
// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: {
|
||||
// graph_proto_path: "[PATH]"
|
||||
// model_options {
|
||||
// checkpoint_path: "[PATH2]"
|
||||
// }
|
||||
// tag_to_tensor_names {
|
||||
// key: "JPG_STRING"
|
||||
// value: "input:0"
|
||||
// }
|
||||
// tag_to_tensor_names {
|
||||
// key: "SOFTMAX"
|
||||
// value: "softmax:0"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||
public:
|
||||
// Counters for recording timing information. The actual names have the value
|
||||
// of CalculatorGraphConfig::Node::name() prepended.
|
||||
static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs";
|
||||
static constexpr char kTotalProcessedTimestampsCounterSuffix[] =
|
||||
"TotalProcessedTimestamps";
|
||||
static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] =
|
||||
"TotalSessionRunsTimeUsecs";
|
||||
static constexpr char kTotalNumSessionRunsCounterSuffix[] =
|
||||
"TotalNumSessionRuns";
|
||||
|
||||
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
std::unique_ptr<InferenceState> inference_state =
|
||||
absl::make_unique<InferenceState>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
||||
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
||||
std::map<std::string, tf::Tensor>* init_tensor_map;
|
||||
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
||||
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
||||
for (const auto& p : *init_tensor_map) {
|
||||
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
||||
TensorFlowInferenceCalculator() : session_(nullptr) {
|
||||
clock_ = std::unique_ptr<mediapipe::Clock>(
|
||||
mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
|
||||
}
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
const auto& options = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||
// The tensorflow::Tensor with the tag equal to the graph node. May
|
||||
// have a TimeSeriesHeader if all present TimeSeriesHeaders match.
|
||||
if (!options.batched_input()) {
|
||||
cc->Inputs().Tag(tag).Set<tf::Tensor>();
|
||||
} else {
|
||||
cc->Inputs().Tag(tag).Set<std::vector<mediapipe::Packet>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
return inference_state;
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
||||
session_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Get<TensorFlowSession>()
|
||||
.session.get();
|
||||
tag_to_tensor_map_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Get<TensorFlowSession>()
|
||||
.tag_to_tensor_map;
|
||||
|
||||
// Validate and store the recurrent tags
|
||||
RET_CHECK(options_.has_batch_size());
|
||||
RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty())
|
||||
<< "To use recurrent_tag_pairs, batch_size must be 1.";
|
||||
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
|
||||
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
|
||||
RET_CHECK_EQ(tags.size(), 2)
|
||||
<< "recurrent_tag_pair must be a colon "
|
||||
"separated std::string with two components: "
|
||||
<< tag_pair;
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
|
||||
<< "Can't find tag '" << tags[0] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
|
||||
<< "Can't find tag '" << tags[1] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
recurrent_feed_tags_.insert(tags[0]);
|
||||
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
|
||||
}
|
||||
|
||||
// Check that all tags are present in this signature bound to tensors.
|
||||
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
}
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
}
|
||||
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
|
||||
if (options_.batch_size() == 1 || options_.batched_input()) {
|
||||
cc->SetOffset(0);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Adds a batch dimension to the input tensor if specified in the calculator
|
||||
// options.
|
||||
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
|
||||
if (options_.add_batch_dim_to_tensors()) {
|
||||
tf::TensorShape new_shape(input_tensor->shape());
|
||||
new_shape.InsertDim(0, 1);
|
||||
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
|
||||
<< "Could not add 0th dimension to tensor without changing its shape."
|
||||
<< " Current shape: " << input_tensor->shape().DebugString();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AggregateTensorPacket(
|
||||
const std::string& tag_name, const Packet& packet,
|
||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
|
||||
input_tensors_by_tag_by_timestamp,
|
||||
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
|
||||
RET_CHECK_OK(AddBatchDimension(&input_tensor));
|
||||
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
|
||||
// If we receive an input on a recurrent tag, override the state.
|
||||
// It's OK to override the global state because there is just one
|
||||
// input stream allowed for recurrent tensors.
|
||||
inference_state_->input_tensor_batches_[tag_name].clear();
|
||||
}
|
||||
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
|
||||
std::make_pair(tag_name, input_tensor));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Removes the batch dimension of the output tensor if specified in the
|
||||
// calculator options.
|
||||
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
|
||||
if (options_.add_batch_dim_to_tensors()) {
|
||||
tf::TensorShape new_shape(output_tensor->shape());
|
||||
new_shape.RemoveDim(0);
|
||||
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
|
||||
<< "Could not remove 0th dimension from tensor without changing its "
|
||||
<< "shape. Current shape: " << output_tensor->shape().DebugString()
|
||||
<< " (The expected first dimension is 1 for a batch element.)";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
std::unique_ptr<InferenceState> inference_state_to_process;
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
if (inference_state_ == nullptr) {
|
||||
inference_state_ = CreateInferenceState(cc);
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
// The tensorflow::Tensor with tag equal to the graph node to
|
||||
// output. Any TimeSeriesHeader from the inputs will be forwarded
|
||||
// with channels set to 0.
|
||||
cc->Outputs().Tag(tag).Set<tf::Tensor>();
|
||||
}
|
||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>
|
||||
input_tensors_by_tag_by_timestamp;
|
||||
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
|
||||
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
|
||||
// Recurrent tensors can be empty.
|
||||
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
|
||||
if (options_.skip_on_missing_features()) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Tag ", tag_as_node_name,
|
||||
" not present at timestamp: ", cc->InputTimestamp().Value()));
|
||||
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
||||
// For this calculator it must include a tag_to_tensor_map.
|
||||
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("RECURRENT_INIT_TENSORS")
|
||||
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
std::unique_ptr<InferenceState> inference_state =
|
||||
absl::make_unique<InferenceState>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
||||
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
||||
std::map<std::string, tf::Tensor>* init_tensor_map;
|
||||
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
||||
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
||||
for (const auto& p : *init_tensor_map) {
|
||||
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
||||
}
|
||||
}
|
||||
return inference_state;
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
||||
session_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Get<TensorFlowSession>()
|
||||
.session.get();
|
||||
tag_to_tensor_map_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Get<TensorFlowSession>()
|
||||
.tag_to_tensor_map;
|
||||
|
||||
// Validate and store the recurrent tags
|
||||
RET_CHECK(options_.has_batch_size());
|
||||
RET_CHECK(options_.batch_size() == 1 ||
|
||||
options_.recurrent_tag_pair().empty())
|
||||
<< "To use recurrent_tag_pairs, batch_size must be 1.";
|
||||
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
|
||||
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
|
||||
RET_CHECK_EQ(tags.size(), 2)
|
||||
<< "recurrent_tag_pair must be a colon "
|
||||
"separated std::string with two components: "
|
||||
<< tag_pair;
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
|
||||
<< "Can't find tag '" << tags[0] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
|
||||
<< "Can't find tag '" << tags[1] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
recurrent_feed_tags_.insert(tags[0]);
|
||||
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
|
||||
}
|
||||
|
||||
// Check that all tags are present in this signature bound to tensors.
|
||||
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
}
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
}
|
||||
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
|
||||
if (options_.batch_size() == 1 || options_.batched_input()) {
|
||||
cc->SetOffset(0);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Adds a batch dimension to the input tensor if specified in the calculator
|
||||
// options.
|
||||
absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
|
||||
if (options_.add_batch_dim_to_tensors()) {
|
||||
tf::TensorShape new_shape(input_tensor->shape());
|
||||
new_shape.InsertDim(0, 1);
|
||||
RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
|
||||
<< "Could not add 0th dimension to tensor without changing its shape."
|
||||
<< " Current shape: " << input_tensor->shape().DebugString();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AggregateTensorPacket(
|
||||
const std::string& tag_name, const Packet& packet,
|
||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>*
|
||||
input_tensors_by_tag_by_timestamp,
|
||||
InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
tf::Tensor input_tensor(packet.Get<tf::Tensor>());
|
||||
RET_CHECK_OK(AddBatchDimension(&input_tensor));
|
||||
if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
|
||||
// If we receive an input on a recurrent tag, override the state.
|
||||
// It's OK to override the global state because there is just one
|
||||
// input stream allowed for recurrent tensors.
|
||||
inference_state_->input_tensor_batches_[tag_name].clear();
|
||||
}
|
||||
(*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
|
||||
std::make_pair(tag_name, input_tensor));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Removes the batch dimension of the output tensor if specified in the
|
||||
// calculator options.
|
||||
absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
|
||||
if (options_.add_batch_dim_to_tensors()) {
|
||||
tf::TensorShape new_shape(output_tensor->shape());
|
||||
new_shape.RemoveDim(0);
|
||||
RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
|
||||
<< "Could not remove 0th dimension from tensor without changing its "
|
||||
<< "shape. Current shape: " << output_tensor->shape().DebugString()
|
||||
<< " (The expected first dimension is 1 for a batch element.)";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
std::unique_ptr<InferenceState> inference_state_to_process;
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
if (inference_state_ == nullptr) {
|
||||
inference_state_ = CreateInferenceState(cc);
|
||||
}
|
||||
std::map<Timestamp, std::map<std::string, tf::Tensor>>
|
||||
input_tensors_by_tag_by_timestamp;
|
||||
for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
|
||||
if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
|
||||
// Recurrent tensors can be empty.
|
||||
if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
|
||||
if (options_.skip_on_missing_features()) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Tag ", tag_as_node_name,
|
||||
" not present at timestamp: ", cc->InputTimestamp().Value()));
|
||||
}
|
||||
}
|
||||
} else if (options_.batched_input()) {
|
||||
const auto& tensor_packets =
|
||||
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
|
||||
if (tensor_packets.size() > options_.batch_size()) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Batch for tag ", tag_as_node_name,
|
||||
" has more packets than batch capacity. batch_size: ",
|
||||
options_.batch_size(), " packets: ", tensor_packets.size()));
|
||||
}
|
||||
for (const auto& packet : tensor_packets) {
|
||||
RET_CHECK_OK(AggregateTensorPacket(
|
||||
tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp,
|
||||
inference_state_.get()));
|
||||
}
|
||||
} else {
|
||||
RET_CHECK_OK(AggregateTensorPacket(
|
||||
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
|
||||
&input_tensors_by_tag_by_timestamp, inference_state_.get()));
|
||||
}
|
||||
} else if (options_.batched_input()) {
|
||||
const auto& tensor_packets =
|
||||
cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
|
||||
if (tensor_packets.size() > options_.batch_size()) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Batch for tag ", tag_as_node_name,
|
||||
" has more packets than batch capacity. batch_size: ",
|
||||
options_.batch_size(), " packets: ", tensor_packets.size()));
|
||||
}
|
||||
for (const auto& timestamp_and_input_tensors_by_tag :
|
||||
input_tensors_by_tag_by_timestamp) {
|
||||
inference_state_->batch_timestamps_.emplace_back(
|
||||
timestamp_and_input_tensors_by_tag.first);
|
||||
for (const auto& input_tensor_and_tag :
|
||||
timestamp_and_input_tensors_by_tag.second) {
|
||||
inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
|
||||
.emplace_back(input_tensor_and_tag.second);
|
||||
}
|
||||
for (const auto& packet : tensor_packets) {
|
||||
RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet,
|
||||
&input_tensors_by_tag_by_timestamp,
|
||||
inference_state_.get()));
|
||||
}
|
||||
if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
|
||||
options_.batched_input()) {
|
||||
inference_state_to_process = std::move(inference_state_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
}
|
||||
|
||||
if (inference_state_to_process) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
OutputBatch(cc, std::move(inference_state_to_process)));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Close(CalculatorContext* cc) override {
|
||||
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
|
||||
!inference_state_->batch_timestamps_.empty()) {
|
||||
inference_state_to_process = std::move(inference_state_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
}
|
||||
if (inference_state_to_process) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
OutputBatch(cc, std::move(inference_state_to_process)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// When a batch of input tensors is ready to be run, runs TensorFlow and
|
||||
// outputs the output tensors. The output tensors have timestamps matching
|
||||
// the input tensor that formed that batch element. Any requested
|
||||
// batch_dimension is added and removed. This code takes advantage of the fact
|
||||
// that copying a tensor shares the same reference-counted, heap allocated
|
||||
// memory buffer. Therefore, copies are cheap and should not cause the memory
|
||||
// buffer to fall out of scope. In contrast, concat is only used where
|
||||
// necessary.
|
||||
absl::Status OutputBatch(CalculatorContext* cc,
|
||||
std::unique_ptr<InferenceState> inference_state) {
|
||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||
|
||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||
if (options_.batch_size() == 1) {
|
||||
// Short circuit to avoid the cost of deep copying tensors in concat.
|
||||
if (!keyed_tensors.second.empty()) {
|
||||
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
||||
keyed_tensors.second[0]);
|
||||
} else {
|
||||
// The input buffer can be empty for recurrent tensors.
|
||||
RET_CHECK(
|
||||
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
|
||||
<< "A non-recurrent tensor does not have an input: "
|
||||
<< keyed_tensors.first;
|
||||
}
|
||||
} else {
|
||||
RET_CHECK_OK(AggregateTensorPacket(
|
||||
tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
|
||||
&input_tensors_by_tag_by_timestamp, inference_state_.get()));
|
||||
}
|
||||
}
|
||||
for (const auto& timestamp_and_input_tensors_by_tag :
|
||||
input_tensors_by_tag_by_timestamp) {
|
||||
inference_state_->batch_timestamps_.emplace_back(
|
||||
timestamp_and_input_tensors_by_tag.first);
|
||||
for (const auto& input_tensor_and_tag :
|
||||
timestamp_and_input_tensors_by_tag.second) {
|
||||
inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
|
||||
.emplace_back(input_tensor_and_tag.second);
|
||||
}
|
||||
}
|
||||
if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
|
||||
options_.batched_input()) {
|
||||
inference_state_to_process = std::move(inference_state_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
}
|
||||
|
||||
if (inference_state_to_process) {
|
||||
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Close(CalculatorContext* cc) override {
|
||||
std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
|
||||
{
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
|
||||
!inference_state_->batch_timestamps_.empty()) {
|
||||
inference_state_to_process = std::move(inference_state_);
|
||||
inference_state_ = std::unique_ptr<InferenceState>();
|
||||
}
|
||||
}
|
||||
if (inference_state_to_process) {
|
||||
MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// When a batch of input tensors is ready to be run, runs TensorFlow and
|
||||
// outputs the output tensors. The output tensors have timestamps matching
|
||||
// the input tensor that formed that batch element. Any requested
|
||||
// batch_dimension is added and removed. This code takes advantage of the fact
|
||||
// that copying a tensor shares the same reference-counted, heap allocated
|
||||
// memory buffer. Therefore, copies are cheap and should not cause the memory
|
||||
// buffer to fall out of scope. In contrast, concat is only used where
|
||||
// necessary.
|
||||
absl::Status OutputBatch(CalculatorContext* cc,
|
||||
std::unique_ptr<InferenceState> inference_state) {
|
||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||
|
||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||
if (options_.batch_size() == 1) {
|
||||
// Short circuit to avoid the cost of deep copying tensors in concat.
|
||||
if (!keyed_tensors.second.empty()) {
|
||||
// Pad by replicating the first tens or, then ignore the values.
|
||||
keyed_tensors.second.resize(options_.batch_size());
|
||||
std::fill(keyed_tensors.second.begin() +
|
||||
inference_state->batch_timestamps_.size(),
|
||||
keyed_tensors.second.end(), keyed_tensors.second[0]);
|
||||
tf::Tensor concated;
|
||||
const tf::Status concat_status =
|
||||
tf::tensor::Concat(keyed_tensors.second, &concated);
|
||||
CHECK(concat_status.ok()) << concat_status.ToString();
|
||||
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
||||
keyed_tensors.second[0]);
|
||||
} else {
|
||||
// The input buffer can be empty for recurrent tensors.
|
||||
RET_CHECK(
|
||||
mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
|
||||
<< "A non-recurrent tensor does not have an input: "
|
||||
<< keyed_tensors.first;
|
||||
concated);
|
||||
}
|
||||
} else {
|
||||
// Pad by replicating the first tens or, then ignore the values.
|
||||
keyed_tensors.second.resize(options_.batch_size());
|
||||
std::fill(keyed_tensors.second.begin() +
|
||||
inference_state->batch_timestamps_.size(),
|
||||
keyed_tensors.second.end(), keyed_tensors.second[0]);
|
||||
tf::Tensor concated;
|
||||
const tf::Status concat_status =
|
||||
tf::tensor::Concat(keyed_tensors.second, &concated);
|
||||
CHECK(concat_status.ok()) << concat_status.ToString();
|
||||
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
|
||||
concated);
|
||||
}
|
||||
}
|
||||
inference_state->input_tensor_batches_.clear();
|
||||
std::vector<mediapipe::ProtoString> output_tensor_names;
|
||||
std::vector<std::string> output_name_in_signature;
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
|
||||
output_name_in_signature.emplace_back(tag);
|
||||
}
|
||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||
// Ensure that we always fetch the recurrent state tensors.
|
||||
if (std::find(output_name_in_signature.begin(),
|
||||
output_name_in_signature.end(),
|
||||
tag_pair.first) == output_name_in_signature.end()) {
|
||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
|
||||
output_name_in_signature.emplace_back(tag_pair.first);
|
||||
inference_state->input_tensor_batches_.clear();
|
||||
std::vector<mediapipe::ProtoString> output_tensor_names;
|
||||
std::vector<std::string> output_name_in_signature;
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
|
||||
output_name_in_signature.emplace_back(tag);
|
||||
}
|
||||
}
|
||||
std::vector<tf::Tensor> outputs;
|
||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||
// Ensure that we always fetch the recurrent state tensors.
|
||||
if (std::find(output_name_in_signature.begin(),
|
||||
output_name_in_signature.end(),
|
||||
tag_pair.first) == output_name_in_signature.end()) {
|
||||
output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
|
||||
output_name_in_signature.emplace_back(tag_pair.first);
|
||||
}
|
||||
}
|
||||
std::vector<tf::Tensor> outputs;
|
||||
|
||||
SimpleSemaphore* session_run_throttle = nullptr;
|
||||
if (options_.max_concurrent_session_runs() > 0) {
|
||||
session_run_throttle =
|
||||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||
session_run_throttle->Acquire(1);
|
||||
}
|
||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
tf::Status tf_status;
|
||||
{
|
||||
SimpleSemaphore* session_run_throttle = nullptr;
|
||||
if (options_.max_concurrent_session_runs() > 0) {
|
||||
session_run_throttle =
|
||||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||
session_run_throttle->Acquire(1);
|
||||
}
|
||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
tf::Status tf_status;
|
||||
{
|
||||
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
||||
tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
|
||||
tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
|
||||
#endif
|
||||
tf_status = session_->Run(input_tensors, output_tensor_names,
|
||||
{} /* target_node_names */, &outputs);
|
||||
}
|
||||
tf_status = session_->Run(input_tensors, output_tensor_names,
|
||||
{} /* target_node_names */, &outputs);
|
||||
}
|
||||
|
||||
if (session_run_throttle != nullptr) {
|
||||
session_run_throttle->Release(1);
|
||||
}
|
||||
if (session_run_throttle != nullptr) {
|
||||
session_run_throttle->Release(1);
|
||||
}
|
||||
|
||||
// RET_CHECK on the tf::Status object itself in order to print an
|
||||
// informative error message.
|
||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||
// RET_CHECK on the tf::Status object itself in order to print an
|
||||
// informative error message.
|
||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||
|
||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||
->IncrementBy(run_end_time - run_start_time);
|
||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||
->IncrementBy(run_end_time - run_start_time);
|
||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||
|
||||
// Feed back the recurrent state.
|
||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||
int pos = std::find(output_name_in_signature.begin(),
|
||||
output_name_in_signature.end(), tag_pair.first) -
|
||||
output_name_in_signature.begin();
|
||||
inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
|
||||
outputs[pos]);
|
||||
}
|
||||
// Feed back the recurrent state.
|
||||
for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
|
||||
int pos = std::find(output_name_in_signature.begin(),
|
||||
output_name_in_signature.end(), tag_pair.first) -
|
||||
output_name_in_signature.begin();
|
||||
inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
|
||||
outputs[pos]);
|
||||
}
|
||||
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
// Set that we want to split on each index of the 0th dimension.
|
||||
std::vector<tf::int64> split_vector(options_.batch_size(), 1);
|
||||
for (int i = 0; i < output_tensor_names.size(); ++i) {
|
||||
if (options_.batch_size() == 1) {
|
||||
if (cc->Outputs().HasTag(output_name_in_signature[i])) {
|
||||
tf::Tensor output_tensor(outputs[i]);
|
||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||
cc->Outputs()
|
||||
.Tag(output_name_in_signature[i])
|
||||
.Add(new tf::Tensor(output_tensor),
|
||||
inference_state->batch_timestamps_[0]);
|
||||
}
|
||||
} else {
|
||||
std::vector<tf::Tensor> split_tensors;
|
||||
const tf::Status split_status =
|
||||
tf::tensor::Split(outputs[i], split_vector, &split_tensors);
|
||||
CHECK(split_status.ok()) << split_status.ToString();
|
||||
// Loop over timestamps so that we don't copy the padding.
|
||||
for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
|
||||
tf::Tensor output_tensor(split_tensors[j]);
|
||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||
cc->Outputs()
|
||||
.Tag(output_name_in_signature[i])
|
||||
.Add(new tf::Tensor(output_tensor),
|
||||
inference_state->batch_timestamps_[j]);
|
||||
absl::WriterMutexLock l(&mutex_);
|
||||
// Set that we want to split on each index of the 0th dimension.
|
||||
std::vector<tf::int64> split_vector(options_.batch_size(), 1);
|
||||
for (int i = 0; i < output_tensor_names.size(); ++i) {
|
||||
if (options_.batch_size() == 1) {
|
||||
if (cc->Outputs().HasTag(output_name_in_signature[i])) {
|
||||
tf::Tensor output_tensor(outputs[i]);
|
||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||
cc->Outputs()
|
||||
.Tag(output_name_in_signature[i])
|
||||
.Add(new tf::Tensor(output_tensor),
|
||||
inference_state->batch_timestamps_[0]);
|
||||
}
|
||||
} else {
|
||||
std::vector<tf::Tensor> split_tensors;
|
||||
const tf::Status split_status =
|
||||
tf::tensor::Split(outputs[i], split_vector, &split_tensors);
|
||||
CHECK(split_status.ok()) << split_status.ToString();
|
||||
// Loop over timestamps so that we don't copy the padding.
|
||||
for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
|
||||
tf::Tensor output_tensor(split_tensors[j]);
|
||||
RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
|
||||
cc->Outputs()
|
||||
.Tag(output_name_in_signature[i])
|
||||
.Add(new tf::Tensor(output_tensor),
|
||||
inference_state->batch_timestamps_[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get end time and report.
|
||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalUsecsCounterSuffix)
|
||||
->IncrementBy(end_time - start_time);
|
||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||
->IncrementBy(inference_state->batch_timestamps_.size());
|
||||
|
||||
// Make sure we hold on to the recursive state.
|
||||
if (!options_.recurrent_tag_pair().empty()) {
|
||||
inference_state_ = std::move(inference_state);
|
||||
inference_state_->batch_timestamps_.clear();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Get end time and report.
|
||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalUsecsCounterSuffix)->IncrementBy(end_time - start_time);
|
||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||
->IncrementBy(inference_state->batch_timestamps_.size());
|
||||
private:
|
||||
// The Session object is provided by a packet factory and is owned by the
|
||||
// MediaPipe framework. Individual calls are thread-safe, but session state
|
||||
// may be shared across threads.
|
||||
tf::Session* session_;
|
||||
|
||||
// Make sure we hold on to the recursive state.
|
||||
if (!options_.recurrent_tag_pair().empty()) {
|
||||
inference_state_ = std::move(inference_state);
|
||||
inference_state_->batch_timestamps_.clear();
|
||||
// A mapping between stream tags and the tensor names they are bound to.
|
||||
std::map<std::string, std::string> tag_to_tensor_map_;
|
||||
|
||||
absl::Mutex mutex_;
|
||||
std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
// The options for the calculator.
|
||||
TensorFlowInferenceCalculatorOptions options_;
|
||||
|
||||
// Store the feed and fetch tags for feed/fetch recurrent networks.
|
||||
std::set<std::string> recurrent_feed_tags_;
|
||||
std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;
|
||||
|
||||
// Clock used to measure the computation time in OutputBatch().
|
||||
std::unique_ptr<mediapipe::Clock> clock_;
|
||||
|
||||
// The static singleton semaphore to throttle concurrent session runs.
|
||||
static SimpleSemaphore* get_session_run_throttle(
|
||||
int32 max_concurrent_session_runs) {
|
||||
static SimpleSemaphore* session_run_throttle =
|
||||
new SimpleSemaphore(max_concurrent_session_runs);
|
||||
return session_run_throttle;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
// The Session object is provided by a packet factory and is owned by the
|
||||
// MediaPipe framework. Individual calls are thread-safe, but session state may
|
||||
// be shared across threads.
|
||||
tf::Session* session_;
|
||||
|
||||
// A mapping between stream tags and the tensor names they are bound to.
|
||||
std::map<std::string, std::string> tag_to_tensor_map_;
|
||||
|
||||
absl::Mutex mutex_;
|
||||
std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
// The options for the calculator.
|
||||
TensorFlowInferenceCalculatorOptions options_;
|
||||
|
||||
// Store the feed and fetch tags for feed/fetch recurrent networks.
|
||||
std::set<std::string> recurrent_feed_tags_;
|
||||
std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;
|
||||
|
||||
// Clock used to measure the computation time in OutputBatch().
|
||||
std::unique_ptr<mediapipe::Clock> clock_;
|
||||
|
||||
// The static singleton semaphore to throttle concurrent session runs.
|
||||
static SimpleSemaphore* get_session_run_throttle(
|
||||
int32 max_concurrent_session_runs) {
|
||||
static SimpleSemaphore* session_run_throttle =
|
||||
new SimpleSemaphore(max_concurrent_session_runs);
|
||||
return session_run_throttle;
|
||||
}
|
||||
}
|
||||
;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);
|
||||
|
||||
constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];
|
||||
|
|
|
@ -80,6 +80,7 @@ const std::string MaybeConvertSignatureToTag(
|
|||
// which in turn contains a TensorFlow Session ready for execution and a map
|
||||
// between tags and tensor names.
|
||||
//
|
||||
//
|
||||
// Example usage:
|
||||
// node {
|
||||
// calculator: "TensorFlowSessionFromSavedModelCalculator"
|
||||
|
|
|
@ -217,38 +217,41 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
first_timestamp_seen_ = recent_timestamp;
|
||||
}
|
||||
}
|
||||
if (recent_timestamp > last_timestamp_seen) {
|
||||
if (recent_timestamp > last_timestamp_seen &&
|
||||
recent_timestamp < Timestamp::PostStream().Value()) {
|
||||
last_timestamp_key_ = map_kv.first;
|
||||
last_timestamp_seen = recent_timestamp;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!timestamps_.empty()) {
|
||||
RET_CHECK(!last_timestamp_key_.empty())
|
||||
<< "Something went wrong because the timestamp key is unset. "
|
||||
"Example: "
|
||||
<< sequence_->DebugString();
|
||||
RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
|
||||
<< "Something went wrong because the last timestamp is unset. "
|
||||
"Example: "
|
||||
<< sequence_->DebugString();
|
||||
RET_CHECK_LT(first_timestamp_seen_,
|
||||
Timestamp::OneOverPostStream().Value())
|
||||
<< "Something went wrong because the first timestamp is unset. "
|
||||
"Example: "
|
||||
<< sequence_->DebugString();
|
||||
for (const auto& kv : timestamps_) {
|
||||
if (!kv.second.empty() &&
|
||||
kv.second[0] < Timestamp::PostStream().Value()) {
|
||||
// These checks only make sense if any values are not PostStream, but
|
||||
// only need to be made once.
|
||||
RET_CHECK(!last_timestamp_key_.empty())
|
||||
<< "Something went wrong because the timestamp key is unset. "
|
||||
<< "Example: " << sequence_->DebugString();
|
||||
RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
|
||||
<< "Something went wrong because the last timestamp is unset. "
|
||||
<< "Example: " << sequence_->DebugString();
|
||||
RET_CHECK_LT(first_timestamp_seen_,
|
||||
Timestamp::OneOverPostStream().Value())
|
||||
<< "Something went wrong because the first timestamp is unset. "
|
||||
<< "Example: " << sequence_->DebugString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
current_timestamp_index_ = 0;
|
||||
process_poststream_ = false;
|
||||
|
||||
// Determine the data path and output it.
|
||||
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
|
||||
const auto& sequence = cc->InputSidePackets()
|
||||
.Tag(kSequenceExampleTag)
|
||||
.Get<tensorflow::SequenceExample>();
|
||||
if (cc->Outputs().HasTag(kKeypointsTag)) {
|
||||
keypoint_names_ = absl::StrSplit(options.keypoint_names(), ',');
|
||||
default_keypoint_location_ = options.default_keypoint_location();
|
||||
}
|
||||
if (cc->OutputSidePackets().HasTag(kDataPath)) {
|
||||
std::string root_directory = "";
|
||||
if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
|
||||
|
@ -349,19 +352,30 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
// all packets on all streams that have a timestamp between the current
|
||||
// reference timestep and the previous reference timestep. This ensures that
|
||||
// we emit all timestamps in order, but also only emit a limited number in
|
||||
// any particular call to Process().
|
||||
int64 start_timestamp =
|
||||
timestamps_[last_timestamp_key_][current_timestamp_index_];
|
||||
if (current_timestamp_index_ == 0) {
|
||||
start_timestamp = first_timestamp_seen_;
|
||||
// any particular call to Process(). At the every end, we output the
|
||||
// poststream packets. If we only have poststream packets,
|
||||
// last_timestamp_key_ will be empty.
|
||||
int64 start_timestamp = 0;
|
||||
int64 end_timestamp = 0;
|
||||
if (last_timestamp_key_.empty() || process_poststream_) {
|
||||
process_poststream_ = true;
|
||||
start_timestamp = Timestamp::PostStream().Value();
|
||||
end_timestamp = Timestamp::OneOverPostStream().Value();
|
||||
} else {
|
||||
start_timestamp =
|
||||
timestamps_[last_timestamp_key_][current_timestamp_index_];
|
||||
if (current_timestamp_index_ == 0) {
|
||||
start_timestamp = first_timestamp_seen_;
|
||||
}
|
||||
|
||||
end_timestamp = start_timestamp + 1; // Base case at end of sequence.
|
||||
if (current_timestamp_index_ <
|
||||
timestamps_[last_timestamp_key_].size() - 1) {
|
||||
end_timestamp =
|
||||
timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
|
||||
}
|
||||
}
|
||||
|
||||
int64 end_timestamp = start_timestamp + 1; // Base case at end of sequence.
|
||||
if (current_timestamp_index_ <
|
||||
timestamps_[last_timestamp_key_].size() - 1) {
|
||||
end_timestamp =
|
||||
timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
|
||||
}
|
||||
for (const auto& map_kv : timestamps_) {
|
||||
for (int i = 0; i < map_kv.second.size(); ++i) {
|
||||
if (map_kv.second[i] >= start_timestamp &&
|
||||
|
@ -438,7 +452,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return tool::StatusStop();
|
||||
if (process_poststream_) {
|
||||
// Once we've processed the PostStream timestamp we can stop.
|
||||
return tool::StatusStop();
|
||||
} else {
|
||||
// Otherwise, we still need to do one more pass to process it.
|
||||
process_poststream_ = true;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -462,6 +483,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
std::vector<std::string> keypoint_names_;
|
||||
// Default keypoint location when missing.
|
||||
float default_keypoint_location_;
|
||||
bool process_poststream_;
|
||||
};
|
||||
REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -412,6 +412,72 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
|||
::testing::Eq(Timestamp::PostStream()));
|
||||
}
|
||||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
|
||||
SetUpCalculator({"IMAGE:images"}, {});
|
||||
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||
std::string test_video_id = "test_video_id";
|
||||
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||
|
||||
std::string test_image_string = "test_image_string";
|
||||
|
||||
int num_images = 1;
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
mpms::AddImageTimestamp(i, input_sequence.get());
|
||||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
|
||||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||
input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE").packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
const std::string& output_image = output_packets[i].Get<std::string>();
|
||||
ASSERT_EQ(output_image, test_image_string);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
||||
SetUpCalculator({"FLOAT_FEATURE_FDENSE_MAX:max"}, {});
|
||||
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||
std::string test_video_id = "test_video_id";
|
||||
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||
|
||||
std::string test_image_string = "test_image_string";
|
||||
|
||||
int num_images = 1;
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
mpms::AddImageTimestamp(i, input_sequence.get());
|
||||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
mpms::AddFeatureFloats("FDENSE_MAX", {3.0f, 4.0f}, input_sequence.get());
|
||||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||
input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& fdense_max_packets =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
||||
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||
const auto& fdense_max_vector =
|
||||
fdense_max_packets[0].Get<std::vector<float>>();
|
||||
ASSERT_THAT(fdense_max_vector, ::testing::ElementsAreArray({3.0f, 4.0f}));
|
||||
ASSERT_THAT(fdense_max_packets[0].Timestamp(),
|
||||
::testing::Eq(Timestamp::PostStream()));
|
||||
}
|
||||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
||||
|
||||
|
|
|
@ -904,7 +904,8 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
|||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
options.compile_options.precision_loss_allowed =
|
||||
allow_precision_loss_ ? 1 : 0;
|
||||
options.compile_options.preferred_gl_object_type =
|
||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||
options.compile_options.dynamic_batch_enabled = 0;
|
||||
|
@ -968,7 +969,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
|||
const int kHalfSize = 2; // sizeof(half)
|
||||
// Configure and create the delegate.
|
||||
TFLGpuDelegateOptions options;
|
||||
options.allow_precision_loss = true;
|
||||
options.allow_precision_loss = allow_precision_loss_;
|
||||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive;
|
||||
if (!delegate_)
|
||||
delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options),
|
||||
|
@ -1080,9 +1081,10 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
// Create converter for GPU output.
|
||||
converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:true
|
||||
convertToPBHWC4:false];
|
||||
converter_from_BPHWC4_ =
|
||||
[[TFLBufferConvert alloc] initWithDevice:device
|
||||
isFloat16:allow_precision_loss_
|
||||
convertToPBHWC4:false];
|
||||
if (converter_from_BPHWC4_ == nil) {
|
||||
return absl::InternalError(
|
||||
"Error initializating output buffer converter");
|
||||
|
|
|
@ -439,7 +439,7 @@ absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu(
|
|||
|
||||
// Run shader, upsample result.
|
||||
{
|
||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(output_texture);
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(GL_TEXTURE_2D, small_mask_texture.id());
|
||||
GlRender();
|
||||
|
|
|
@ -821,6 +821,25 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "landmark_projection_calculator_test",
|
||||
srcs = ["landmark_projection_calculator_test.cc"],
|
||||
deps = [
|
||||
":landmark_projection_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_utils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "landmarks_smoothing_calculator_proto",
|
||||
srcs = ["landmarks_smoothing_calculator.proto"],
|
||||
|
@ -1252,3 +1271,45 @@ cc_test(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "refine_landmarks_from_heatmap_calculator_proto",
|
||||
srcs = ["refine_landmarks_from_heatmap_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "refine_landmarks_from_heatmap_calculator",
|
||||
srcs = ["refine_landmarks_from_heatmap_calculator.cc"],
|
||||
hdrs = ["refine_landmarks_from_heatmap_calculator.h"],
|
||||
copts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":refine_landmarks_from_heatmap_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "refine_landmarks_from_heatmap_calculator_test",
|
||||
srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"],
|
||||
deps = [
|
||||
":refine_landmarks_from_heatmap_calculator",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -402,7 +402,7 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc,
|
|||
|
||||
// Blend overlay image in GPU shader.
|
||||
{
|
||||
gpu_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
|
||||
gpu_helper_.BindFramebuffer(output_texture);
|
||||
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(GL_TEXTURE_2D, input_texture.name());
|
||||
|
|
|
@ -54,6 +54,7 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
|
|||
|
||||
private:
|
||||
absl::node_hash_map<int, std::string> label_map_;
|
||||
::mediapipe::DetectionLabelIdToTextCalculatorOptions options_;
|
||||
};
|
||||
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
|
||||
|
||||
|
@ -68,13 +69,13 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
|
|||
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const auto& options =
|
||||
options_ =
|
||||
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
|
||||
|
||||
if (options.has_label_map_path()) {
|
||||
if (options_.has_label_map_path()) {
|
||||
std::string string_path;
|
||||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options.label_map_path()));
|
||||
PathToResourceAsFile(options_.label_map_path()));
|
||||
std::string label_map_string;
|
||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
||||
|
||||
|
@ -85,8 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
|||
label_map_[i++] = line;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < options.label_size(); ++i) {
|
||||
label_map_[i] = options.label(i);
|
||||
for (int i = 0; i < options_.label_size(); ++i) {
|
||||
label_map_[i] = options_.label(i);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -106,7 +107,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
// Remove label_id field if text labels exist.
|
||||
if (has_text_label) {
|
||||
if (has_text_label && !options_.keep_label_id()) {
|
||||
output_detection.clear_label_id();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,4 +31,9 @@ message DetectionLabelIdToTextCalculatorOptions {
|
|||
// label: "label for id 1"
|
||||
// ...
|
||||
repeated string label = 2;
|
||||
|
||||
// By default, the `label_id` field from the input is stripped if a text label
|
||||
// could be found. By setting this field to true, it is always copied to the
|
||||
// output detections.
|
||||
optional bool keep_label_id = 3;
|
||||
}
|
||||
|
|
|
@ -120,7 +120,11 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
labels.resize(classifications.classification_size());
|
||||
scores.resize(classifications.classification_size());
|
||||
for (int i = 0; i < classifications.classification_size(); ++i) {
|
||||
labels[i] = classifications.classification(i).label();
|
||||
if (options_.use_display_name()) {
|
||||
labels[i] = classifications.classification(i).display_name();
|
||||
} else {
|
||||
labels[i] = classifications.classification(i).label();
|
||||
}
|
||||
scores[i] = classifications.classification(i).score();
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -59,4 +59,7 @@ message LabelsToRenderDataCalculatorOptions {
|
|||
BOTTOM_LEFT = 1;
|
||||
}
|
||||
optional Location location = 6 [default = TOP_LEFT];
|
||||
|
||||
// Uses Classification.display_name field instead of Classification.label.
|
||||
optional bool use_display_name = 9 [default = false];
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/util/landmark_projection_calculator.pb.h"
|
||||
|
@ -27,20 +28,32 @@ namespace {
|
|||
|
||||
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
|
||||
constexpr char kRectTag[] = "NORM_RECT";
|
||||
constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX";
|
||||
|
||||
} // namespace
|
||||
|
||||
// Projects normalized landmarks in a rectangle to its original coordinates. The
|
||||
// rectangle must also be in normalized coordinates.
|
||||
// Projects normalized landmarks to its original coordinates.
|
||||
// Input:
|
||||
// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks
|
||||
// in a normalized rectangle.
|
||||
// NORM_RECT: An NormalizedRect representing a normalized rectangle in image
|
||||
// coordinates.
|
||||
// NORM_LANDMARKS - NormalizedLandmarkList
|
||||
// Represents landmarks in a normalized rectangle if NORM_RECT is specified
|
||||
// or landmarks that should be projected using PROJECTION_MATRIX if
|
||||
// specified. (Prefer using PROJECTION_MATRIX as it eliminates need of
|
||||
// letterbox removal step.)
|
||||
// NORM_RECT - NormalizedRect
|
||||
// Represents a normalized rectangle in image coordinates and results in
|
||||
// landmarks with their locations adjusted to the image.
|
||||
// PROJECTION_MATRIX - std::array<float, 16>
|
||||
// A 4x4 row-major-order matrix that maps landmarks' locations from one
|
||||
// coordinate system to another. In this case from the coordinate system of
|
||||
// the normalized region of interest to the coordinate system of the image.
|
||||
//
|
||||
// Note: either NORM_RECT or PROJECTION_MATRIX has to be specified.
|
||||
// Note: landmark's Z is projected in a custom way - it's scaled by width of
|
||||
// the normalized region of interest used during landmarks detection.
|
||||
//
|
||||
// Output:
|
||||
// NORM_LANDMARKS: A NormalizedLandmarkList representing landmarks
|
||||
// with their locations adjusted to the image.
|
||||
// NORM_LANDMARKS - NormalizedLandmarkList
|
||||
// Landmarks with their locations adjusted according to the inputs.
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
|
@ -58,12 +71,27 @@ constexpr char kRectTag[] = "NORM_RECT";
|
|||
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
|
||||
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "LandmarkProjectionCalculator"
|
||||
// input_stream: "NORM_LANDMARKS:landmarks"
|
||||
// input_stream: "PROECTION_MATRIX:matrix"
|
||||
// output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "LandmarkProjectionCalculator"
|
||||
// input_stream: "NORM_LANDMARKS:0:landmarks_0"
|
||||
// input_stream: "NORM_LANDMARKS:1:landmarks_1"
|
||||
// input_stream: "PROECTION_MATRIX:matrix"
|
||||
// output_stream: "NORM_LANDMARKS:0:projected_landmarks_0"
|
||||
// output_stream: "NORM_LANDMARKS:1:projected_landmarks_1"
|
||||
// }
|
||||
class LandmarkProjectionCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) &&
|
||||
cc->Inputs().HasTag(kRectTag))
|
||||
<< "Missing one or more input streams.";
|
||||
RET_CHECK(cc->Inputs().HasTag(kLandmarksTag))
|
||||
<< "Missing NORM_LANDMARKS input.";
|
||||
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(kLandmarksTag),
|
||||
cc->Outputs().NumEntries(kLandmarksTag))
|
||||
|
@ -73,7 +101,14 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
|||
id != cc->Inputs().EndId(kLandmarksTag); ++id) {
|
||||
cc->Inputs().Get(id).Set<NormalizedLandmarkList>();
|
||||
}
|
||||
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
|
||||
RET_CHECK(cc->Inputs().HasTag(kRectTag) ^
|
||||
cc->Inputs().HasTag(kProjectionMatrix))
|
||||
<< "Either NORM_RECT or PROJECTION_MATRIX must be specified.";
|
||||
if (cc->Inputs().HasTag(kRectTag)) {
|
||||
cc->Inputs().Tag(kRectTag).Set<NormalizedRect>();
|
||||
} else {
|
||||
cc->Inputs().Tag(kProjectionMatrix).Set<std::array<float, 16>>();
|
||||
}
|
||||
|
||||
for (CollectionItemId id = cc->Outputs().BeginId(kLandmarksTag);
|
||||
id != cc->Outputs().EndId(kLandmarksTag); ++id) {
|
||||
|
@ -89,31 +124,50 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
static void ProjectXY(const NormalizedLandmark& lm,
|
||||
const std::array<float, 16>& matrix,
|
||||
NormalizedLandmark* out) {
|
||||
out->set_x(lm.x() * matrix[0] + lm.y() * matrix[1] + lm.z() * matrix[2] +
|
||||
matrix[3]);
|
||||
out->set_y(lm.x() * matrix[4] + lm.y() * matrix[5] + lm.z() * matrix[6] +
|
||||
matrix[7]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Landmark's Z scale is equal to a relative (to image) width of region of
|
||||
* interest used during detection. To calculate based on matrix:
|
||||
* 1. Project (0,0) --- (1,0) segment using matrix.
|
||||
* 2. Calculate length of the projected segment.
|
||||
*/
|
||||
static float CalculateZScale(const std::array<float, 16>& matrix) {
|
||||
NormalizedLandmark a;
|
||||
a.set_x(0.0f);
|
||||
a.set_y(0.0f);
|
||||
NormalizedLandmark b;
|
||||
b.set_x(1.0f);
|
||||
b.set_y(0.0f);
|
||||
NormalizedLandmark a_projected;
|
||||
ProjectXY(a, matrix, &a_projected);
|
||||
NormalizedLandmark b_projected;
|
||||
ProjectXY(b, matrix, &b_projected);
|
||||
return std::sqrt(std::pow(b_projected.x() - a_projected.x(), 2) +
|
||||
std::pow(b_projected.y() - a_projected.y(), 2));
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (cc->Inputs().Tag(kRectTag).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& input_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
|
||||
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::LandmarkProjectionCalculatorOptions>();
|
||||
|
||||
CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag);
|
||||
CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag);
|
||||
// Number of inputs and outpus is the same according to the contract.
|
||||
for (; input_id != cc->Inputs().EndId(kLandmarksTag);
|
||||
++input_id, ++output_id) {
|
||||
const auto& input_packet = cc->Inputs().Get(input_id);
|
||||
if (input_packet.IsEmpty()) {
|
||||
continue;
|
||||
std::function<void(const NormalizedLandmark&, NormalizedLandmark*)>
|
||||
project_fn;
|
||||
if (cc->Inputs().HasTag(kRectTag)) {
|
||||
if (cc->Inputs().Tag(kRectTag).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& input_landmarks = input_packet.Get<NormalizedLandmarkList>();
|
||||
NormalizedLandmarkList output_landmarks;
|
||||
for (int i = 0; i < input_landmarks.landmark_size(); ++i) {
|
||||
const NormalizedLandmark& landmark = input_landmarks.landmark(i);
|
||||
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
|
||||
|
||||
const auto& input_rect = cc->Inputs().Tag(kRectTag).Get<NormalizedRect>();
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::LandmarkProjectionCalculatorOptions>();
|
||||
project_fn = [&input_rect, &options](const NormalizedLandmark& landmark,
|
||||
NormalizedLandmark* new_landmark) {
|
||||
// TODO: fix projection or deprecate (current projection
|
||||
// calculations are incorrect for general case).
|
||||
const float x = landmark.x() - 0.5f;
|
||||
const float y = landmark.y() - 0.5f;
|
||||
const float angle =
|
||||
|
@ -130,10 +184,44 @@ class LandmarkProjectionCalculator : public CalculatorBase {
|
|||
new_landmark->set_x(new_x);
|
||||
new_landmark->set_y(new_y);
|
||||
new_landmark->set_z(new_z);
|
||||
};
|
||||
} else if (cc->Inputs().HasTag(kProjectionMatrix)) {
|
||||
if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& project_mat =
|
||||
cc->Inputs().Tag(kProjectionMatrix).Get<std::array<float, 16>>();
|
||||
const float z_scale = CalculateZScale(project_mat);
|
||||
project_fn = [&project_mat, z_scale](const NormalizedLandmark& lm,
|
||||
NormalizedLandmark* new_landmark) {
|
||||
*new_landmark = lm;
|
||||
ProjectXY(lm, project_mat, new_landmark);
|
||||
new_landmark->set_z(z_scale * lm.z());
|
||||
};
|
||||
} else {
|
||||
return absl::InternalError("Either rect or matrix must be specified.");
|
||||
}
|
||||
|
||||
CollectionItemId input_id = cc->Inputs().BeginId(kLandmarksTag);
|
||||
CollectionItemId output_id = cc->Outputs().BeginId(kLandmarksTag);
|
||||
// Number of inputs and outpus is the same according to the contract.
|
||||
for (; input_id != cc->Inputs().EndId(kLandmarksTag);
|
||||
++input_id, ++output_id) {
|
||||
const auto& input_packet = cc->Inputs().Get(input_id);
|
||||
if (input_packet.IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& input_landmarks = input_packet.Get<NormalizedLandmarkList>();
|
||||
NormalizedLandmarkList output_landmarks;
|
||||
for (int i = 0; i < input_landmarks.landmark_size(); ++i) {
|
||||
const NormalizedLandmark& landmark = input_landmarks.landmark(i);
|
||||
NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
|
||||
project_fn(landmark, new_landmark);
|
||||
}
|
||||
|
||||
cc->Outputs().Get(output_id).AddPacket(
|
||||
MakePacket<NormalizedLandmarkList>(output_landmarks)
|
||||
MakePacket<NormalizedLandmarkList>(std::move(output_landmarks))
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/message_matchers.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
|
||||
mediapipe::CalculatorRunner runner(
|
||||
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "LandmarkProjectionCalculator"
|
||||
input_stream: "NORM_LANDMARKS:landmarks"
|
||||
input_stream: "NORM_RECT:rect"
|
||||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||
)pb"));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_LANDMARKS")
|
||||
.packets.push_back(
|
||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||
.At(Timestamp(1)));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_RECT")
|
||||
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
|
||||
.At(Timestamp(1)));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||
RET_CHECK_EQ(output_packets.size(), 1);
|
||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithDefaultRect) {
|
||||
mediapipe::NormalizedLandmarkList landmarks =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
)pb");
|
||||
mediapipe::NormalizedRect rect =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedRect>(
|
||||
R"pb(
|
||||
x_center: 0.5,
|
||||
y_center: 0.5,
|
||||
width: 1.0,
|
||||
height: 1.0,
|
||||
rotation: 0.0
|
||||
)pb");
|
||||
|
||||
auto status_or_result = RunCalculator(std::move(landmarks), std::move(rect));
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(
|
||||
status_or_result.value(),
|
||||
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
)pb")));
|
||||
}
|
||||
|
||||
mediapipe::NormalizedRect GetCroppedRect() {
|
||||
return ParseTextProtoOrDie<mediapipe::NormalizedRect>(
|
||||
R"pb(
|
||||
x_center: 0.5, y_center: 0.5, width: 0.5, height: 2, rotation: 0.0
|
||||
)pb");
|
||||
}
|
||||
|
||||
mediapipe::NormalizedLandmarkList GetCroppedRectTestInput() {
|
||||
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 1.0, y: 1.0, z: -0.5 }
|
||||
)pb");
|
||||
}
|
||||
|
||||
mediapipe::NormalizedLandmarkList GetCroppedRectTestExpectedResult() {
|
||||
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 0.75, y: 1.5, z: -0.25 }
|
||||
)pb");
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRect) {
|
||||
auto status_or_result =
|
||||
RunCalculator(GetCroppedRectTestInput(), GetCroppedRect());
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(status_or_result.value(),
|
||||
EqualsProto(GetCroppedRectTestExpectedResult()));
|
||||
}
|
||||
|
||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||
mediapipe::NormalizedLandmarkList input, std::array<float, 16> matrix) {
|
||||
mediapipe::CalculatorRunner runner(
|
||||
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "LandmarkProjectionCalculator"
|
||||
input_stream: "NORM_LANDMARKS:landmarks"
|
||||
input_stream: "PROJECTION_MATRIX:matrix"
|
||||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||
)pb"));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_LANDMARKS")
|
||||
.packets.push_back(
|
||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||
.At(Timestamp(1)));
|
||||
runner.MutableInputs()
|
||||
->Tag("PROJECTION_MATRIX")
|
||||
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
|
||||
.At(Timestamp(1)));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||
RET_CHECK_EQ(output_packets.size(), 1);
|
||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithIdentityMatrix) {
|
||||
mediapipe::NormalizedLandmarkList landmarks =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
)pb");
|
||||
// clang-format off
|
||||
std::array<float, 16> matrix = {
|
||||
1.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 1.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 1.0f,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
auto status_or_result =
|
||||
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(
|
||||
status_or_result.value(),
|
||||
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRectMatrix) {
|
||||
constexpr int kRectWidth = 1280;
|
||||
constexpr int kRectHeight = 720;
|
||||
auto roi = GetRoi(kRectWidth, kRectHeight, GetCroppedRect());
|
||||
std::array<float, 16> matrix;
|
||||
GetRotatedSubRectToRectTransformMatrix(roi, kRectWidth, kRectHeight,
|
||||
/*flip_horizontaly=*/false, &matrix);
|
||||
auto status_or_result = RunCalculator(GetCroppedRectTestInput(), matrix);
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(status_or_result.value(),
|
||||
EqualsProto(GetCroppedRectTestExpectedResult()));
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithScaleMatrix) {
|
||||
mediapipe::NormalizedLandmarkList landmarks =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
landmark { x: 5, y: 6, z: 7 }
|
||||
)pb");
|
||||
// clang-format off
|
||||
std::array<float, 16> matrix = {
|
||||
10.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 100.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 1.0f,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
auto status_or_result =
|
||||
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(
|
||||
status_or_result.value(),
|
||||
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 100, y: 2000, z: -5 }
|
||||
landmark { x: 50, y: 600, z: 70 }
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithTranslateMatrix) {
|
||||
mediapipe::NormalizedLandmarkList landmarks =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 10, y: 20, z: -0.5 }
|
||||
)pb");
|
||||
// clang-format off
|
||||
std::array<float, 16> matrix = {
|
||||
1.0f, 0.0f, 0.0f, 1.0f,
|
||||
0.0f, 1.0f, 0.0f, 2.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 1.0f,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
auto status_or_result =
|
||||
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(
|
||||
status_or_result.value(),
|
||||
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 11, y: 22, z: -0.5 }
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST(LandmarkProjectionCalculatorTest, ProjectingWithRotationMatrix) {
|
||||
mediapipe::NormalizedLandmarkList landmarks =
|
||||
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 4, y: 0, z: -0.5 }
|
||||
)pb");
|
||||
// clang-format off
|
||||
// 90 degrees rotation matrix
|
||||
std::array<float, 16> matrix = {
|
||||
0.0f, -1.0f, 0.0f, 0.0f,
|
||||
1.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 1.0f,
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
auto status_or_result =
|
||||
RunCalculator(std::move(landmarks), std::move(matrix));
|
||||
MP_ASSERT_OK(status_or_result);
|
||||
|
||||
EXPECT_THAT(
|
||||
status_or_result.value(),
|
||||
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
|
||||
landmark { x: 0, y: 4, z: -0.5 }
|
||||
)pb")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -1,3 +1,17 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
|
||||
|
||||
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
inline float Sigmoid(float value) { return 1.0f / (1.0f + std::exp(-value)); }
|
||||
|
||||
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
|
||||
const std::vector<int>& dims) {
|
||||
if (dims.size() == 3) {
|
||||
return std::make_tuple(dims[0], dims[1], dims[2]);
|
||||
} else if (dims.size() == 4) {
|
||||
// BHWC format check B == 1
|
||||
RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
|
||||
return std::make_tuple(dims[1], dims[2], dims[3]);
|
||||
} else {
|
||||
RET_CHECK(false) << "Invalid shape size for heatmap tensor" << dims.size();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace api2 {
|
||||
|
||||
// Refines landmarks using correspond heatmap area.
|
||||
//
|
||||
// Input:
|
||||
// NORM_LANDMARKS - Required. Input normalized landmarks to update.
|
||||
// TENSORS - Required. Vector of input tensors. 0th element should be heatmap.
|
||||
// The rest is unused.
|
||||
// Output:
|
||||
// NORM_LANDMARKS - Required. Updated normalized landmarks.
|
||||
class RefineLandmarksFromHeatmapCalculatorImpl
|
||||
: public NodeImpl<RefineLandmarksFromHeatmapCalculator,
|
||||
RefineLandmarksFromHeatmapCalculatorImpl> {
|
||||
public:
|
||||
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
// Make sure we bypass landmarks if there is no detection.
|
||||
if (kInLandmarks(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// If for some reason heatmap is missing, just return original landmarks.
|
||||
if (kInTensors(cc).IsEmpty()) {
|
||||
kOutLandmarks(cc).Send(*kInLandmarks(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Check basic prerequisites.
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty()) << "Empty input tensors list. First "
|
||||
"element is expeced to be a heatmap";
|
||||
|
||||
const auto& hm_tensor = input_tensors[0];
|
||||
const auto& in_lms = *kInLandmarks(cc);
|
||||
auto hm_view = hm_tensor.GetCpuReadView();
|
||||
auto hm_raw = hm_view.buffer<float>();
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::RefineLandmarksFromHeatmapCalculatorOptions>();
|
||||
|
||||
ASSIGN_OR_RETURN(auto out_lms, RefineLandmarksFromHeatMap(
|
||||
in_lms, hm_raw, hm_tensor.shape().dims,
|
||||
options.kernel_size(),
|
||||
options.min_confidence_to_refine()));
|
||||
|
||||
kOutLandmarks(cc).Send(std::move(out_lms));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
|
||||
// Runs actual refinement
|
||||
// High level algorithm:
|
||||
//
|
||||
// Heatmap is accepted as tensor in HWC layout where i-th channel is a heatmap
|
||||
// for the i-th landmark.
|
||||
//
|
||||
// For each landmark we replace original value with a value calculated from the
|
||||
// area in heatmap close to original landmark position (in particular are
|
||||
// covered with kernel of size options.kernel_size). To calculate new coordinate
|
||||
// from heatmap we calculate an weighted average inside the kernel. We update
|
||||
// the landmark iff heatmap is confident in it's prediction i.e. max(heatmap) in
|
||||
// kernel is at least options.min_confidence_to_refine big.
|
||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
|
||||
const mediapipe::NormalizedLandmarkList& in_lms,
|
||||
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
|
||||
int kernel_size, float min_confidence_to_refine) {
|
||||
ASSIGN_OR_RETURN(auto hm_dims, GetHwcFromDims(heatmap_dims));
|
||||
auto [hm_height, hm_width, hm_channels] = hm_dims;
|
||||
|
||||
RET_CHECK_EQ(in_lms.landmark_size(), hm_channels)
|
||||
<< "Expected heatmap to have number of layers == to number of "
|
||||
"landmarks";
|
||||
|
||||
int hm_row_size = hm_width * hm_channels;
|
||||
int hm_pixel_size = hm_channels;
|
||||
|
||||
mediapipe::NormalizedLandmarkList out_lms = in_lms;
|
||||
for (int lm_index = 0; lm_index < out_lms.landmark_size(); ++lm_index) {
|
||||
int center_col = out_lms.landmark(lm_index).x() * hm_width;
|
||||
int center_row = out_lms.landmark(lm_index).y() * hm_height;
|
||||
// Point is outside of the image let's keep it intact.
|
||||
if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
|
||||
center_col >= hm_height) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int offset = (kernel_size - 1) / 2;
|
||||
// Calculate area to iterate over. Note that we decrease the kernel on
|
||||
// the edges of the heatmap. Equivalent to zero border.
|
||||
int begin_col = std::max(0, center_col - offset);
|
||||
int end_col = std::min(hm_width, center_col + offset + 1);
|
||||
int begin_row = std::max(0, center_row - offset);
|
||||
int end_row = std::min(hm_height, center_row + offset + 1);
|
||||
|
||||
float sum = 0;
|
||||
float weighted_col = 0;
|
||||
float weighted_row = 0;
|
||||
float max_value = 0;
|
||||
|
||||
// Main loop. Go over kernel and calculate weighted sum of coordinates,
|
||||
// sum of weights and max weights.
|
||||
for (int row = begin_row; row < end_row; ++row) {
|
||||
for (int col = begin_col; col < end_col; ++col) {
|
||||
// We expect memory to be in HWC layout without padding.
|
||||
int idx = hm_row_size * row + hm_pixel_size * col + lm_index;
|
||||
// Right now we hardcode sigmoid activation as it will be wasteful to
|
||||
// calculate sigmoid for each value of heatmap in the model itself. If
|
||||
// we ever have other activations it should be trivial to expand via
|
||||
// options.
|
||||
float confidence = Sigmoid(heatmap_raw_data[idx]);
|
||||
sum += confidence;
|
||||
max_value = std::max(max_value, confidence);
|
||||
weighted_col += col * confidence;
|
||||
weighted_row += row * confidence;
|
||||
}
|
||||
}
|
||||
if (max_value >= min_confidence_to_refine && sum > 0) {
|
||||
out_lms.mutable_landmark(lm_index)->set_x(weighted_col / hm_width / sum);
|
||||
out_lms.mutable_landmark(lm_index)->set_y(weighted_row / hm_height / sum);
|
||||
}
|
||||
}
|
||||
return out_lms;
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
class RefineLandmarksFromHeatmapCalculator : public NodeIntf {
|
||||
public:
|
||||
static constexpr Input<mediapipe::NormalizedLandmarkList> kInLandmarks{
|
||||
"NORM_LANDMARKS"};
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
static constexpr Output<mediapipe::NormalizedLandmarkList> kOutLandmarks{
|
||||
"NORM_LANDMARKS"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(RefineLandmarksFromHeatmapCalculator, kInLandmarks,
|
||||
kInTensors, kOutLandmarks);
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
|
||||
// Exposed for testing.
|
||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
|
||||
const mediapipe::NormalizedLandmarkList& in_lms,
|
||||
const float* heatmap_raw_data, const std::vector<int>& heatmap_dims,
|
||||
int kernel_size, float min_confidence_to_refine);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_UTIL_REFINE_LANDMARKS_FROM_HEATMAP_CALCULATOR_H_
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message RefineLandmarksFromHeatmapCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional RefineLandmarksFromHeatmapCalculatorOptions ext = 362281653;
|
||||
}
|
||||
optional int32 kernel_size = 1 [default = 9];
|
||||
optional float min_confidence_to_refine = 2 [default = 0.5];
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.h"
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
mediapipe::NormalizedLandmarkList vec_to_lms(
|
||||
const std::vector<std::pair<float, float>>& inp) {
|
||||
mediapipe::NormalizedLandmarkList ret;
|
||||
for (const auto& it : inp) {
|
||||
auto new_lm = ret.add_landmark();
|
||||
new_lm->set_x(it.first);
|
||||
new_lm->set_y(it.second);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, float>> lms_to_vec(
|
||||
const mediapipe::NormalizedLandmarkList& lst) {
|
||||
std::vector<std::pair<float, float>> ret;
|
||||
for (const auto& lm : lst.landmark()) {
|
||||
ret.push_back({lm.x(), lm.y()});
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<float> CHW_to_HWC(std::vector<float> inp, int height, int width,
|
||||
int depth) {
|
||||
std::vector<float> ret(inp.size());
|
||||
const float* inp_ptr = inp.data();
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
for (int row = 0; row < height; ++row) {
|
||||
for (int col = 0; col < width; ++col) {
|
||||
int dest_idx = width * depth * row + depth * col + c;
|
||||
ret[dest_idx] = *inp_ptr;
|
||||
++inp_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
using testing::ElementsAre;
|
||||
using testing::FloatEq;
|
||||
using testing::Pair;
|
||||
|
||||
TEST(RefineLandmarksFromHeatmapTest, Smoke) {
|
||||
float z = -10000000000000000;
|
||||
// clang-format off
|
||||
std::vector<float> hm = {
|
||||
z, z, z,
|
||||
1, z, z,
|
||||
z, z, z};
|
||||
// clang-format on
|
||||
|
||||
auto ret_or_error = RefineLandmarksFromHeatMap(vec_to_lms({{0.5, 0.5}}),
|
||||
hm.data(), {3, 3, 1}, 3, 0.1);
|
||||
MP_EXPECT_OK(ret_or_error);
|
||||
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.))));
|
||||
}
|
||||
|
||||
TEST(RefineLandmarksFromHeatmapTest, MultiLayer) {
|
||||
float z = -10000000000000000;
|
||||
// clang-format off
|
||||
std::vector<float> hm = CHW_to_HWC({
|
||||
z, z, z,
|
||||
1, z, z,
|
||||
z, z, z,
|
||||
z, z, z,
|
||||
1, z, z,
|
||||
z, z, z,
|
||||
z, z, z,
|
||||
1, z, z,
|
||||
z, z, z}, 3, 3, 3);
|
||||
// clang-format on
|
||||
|
||||
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
|
||||
0.1);
|
||||
MP_EXPECT_OK(ret_or_error);
|
||||
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||
Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||
Pair(FloatEq(0), FloatEq(1 / 3.))));
|
||||
}
|
||||
|
||||
TEST(RefineLandmarksFromHeatmapTest, KeepIfNotSure) {
|
||||
float z = -10000000000000000;
|
||||
// clang-format off
|
||||
std::vector<float> hm = CHW_to_HWC({
|
||||
z, z, z,
|
||||
0, z, z,
|
||||
z, z, z,
|
||||
z, z, z,
|
||||
0, z, z,
|
||||
z, z, z,
|
||||
z, z, z,
|
||||
0, z, z,
|
||||
z, z, z}, 3, 3, 3);
|
||||
// clang-format on
|
||||
|
||||
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||
vec_to_lms({{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}}), hm.data(), {3, 3, 3}, 3,
|
||||
0.6);
|
||||
MP_EXPECT_OK(ret_or_error);
|
||||
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||
ElementsAre(Pair(FloatEq(0.5), FloatEq(0.5)),
|
||||
Pair(FloatEq(0.5), FloatEq(0.5)),
|
||||
Pair(FloatEq(0.5), FloatEq(0.5))));
|
||||
}
|
||||
|
||||
TEST(RefineLandmarksFromHeatmapTest, Border) {
|
||||
float z = -10000000000000000;
|
||||
// clang-format off
|
||||
std::vector<float> hm = CHW_to_HWC({
|
||||
z, z, z,
|
||||
0, z, 0,
|
||||
z, z, z,
|
||||
|
||||
z, z, z,
|
||||
0, z, 0,
|
||||
z, z, 0}, 3, 3, 2);
|
||||
// clang-format on
|
||||
|
||||
auto ret_or_error = RefineLandmarksFromHeatMap(
|
||||
vec_to_lms({{0.0, 0.0}, {0.9, 0.9}}), hm.data(), {3, 3, 2}, 3, 0.1);
|
||||
MP_EXPECT_OK(ret_or_error);
|
||||
EXPECT_THAT(lms_to_vec(*ret_or_error),
|
||||
ElementsAre(Pair(FloatEq(0), FloatEq(1 / 3.)),
|
||||
Pair(FloatEq(2 / 3.), FloatEq(1 / 6. + 2 / 6.))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -101,7 +101,7 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
||||
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<float>();
|
||||
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -43,8 +43,7 @@ android_binary(
|
|||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
||||
|
|
|
@ -37,7 +37,7 @@ android_binary(
|
|||
srcs = glob(["*.java"]),
|
||||
assets = [
|
||||
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
cc_binary(
|
||||
name = "libmediapipe_jni.so",
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mediapipe_jni_lib",
|
||||
srcs = [":libmediapipe_jni.so"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
android_binary(
|
||||
name = "upperbodyposetrackinggpu",
|
||||
srcs = glob(["*.java"]),
|
||||
assets = [
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml",
|
||||
manifest_values = {
|
||||
"applicationId": "com.google.mediapipe.apps.upperbodyposetrackinggpu",
|
||||
"appName": "Upper Body Pose Tracking",
|
||||
"mainActivity": ".MainActivity",
|
||||
"cameraFacingFront": "False",
|
||||
"binaryGraphName": "upper_body_pose_tracking_gpu.binarypb",
|
||||
"inputVideoStreamName": "input_video",
|
||||
"outputVideoStreamName": "output_video",
|
||||
"flipFramesVertically": "True",
|
||||
"converterNumBuffers": "2",
|
||||
},
|
||||
multidex = "native",
|
||||
deps = [
|
||||
":mediapipe_jni_lib",
|
||||
"//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib",
|
||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"@com_google_protobuf//:protobuf_javalite",
|
||||
],
|
||||
)
|
|
@ -1,75 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.apps.upperbodyposetrackinggpu;
|
||||
|
||||
import android.os.Bundle;
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
|
||||
/** Main activity of MediaPipe upper-body pose tracking app. */
|
||||
public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity {
|
||||
private static final String TAG = "MainActivity";
|
||||
|
||||
private static final String OUTPUT_LANDMARKS_STREAM_NAME = "pose_landmarks";
|
||||
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
|
||||
// To show verbose logging, run:
|
||||
// adb shell setprop log.tag.MainActivity VERBOSE
|
||||
if (Log.isLoggable(TAG, Log.VERBOSE)) {
|
||||
processor.addPacketCallback(
|
||||
OUTPUT_LANDMARKS_STREAM_NAME,
|
||||
(packet) -> {
|
||||
Log.v(TAG, "Received pose landmarks packet.");
|
||||
try {
|
||||
NormalizedLandmarkList poseLandmarks =
|
||||
PacketGetter.getProto(packet, NormalizedLandmarkList.class);
|
||||
Log.v(
|
||||
TAG,
|
||||
"[TS:"
|
||||
+ packet.getTimestamp()
|
||||
+ "] "
|
||||
+ getPoseLandmarksDebugString(poseLandmarks));
|
||||
} catch (InvalidProtocolBufferException exception) {
|
||||
Log.e(TAG, "Failed to get proto.", exception);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private static String getPoseLandmarksDebugString(NormalizedLandmarkList poseLandmarks) {
|
||||
String poseLandmarkStr = "Pose landmarks: " + poseLandmarks.getLandmarkCount() + "\n";
|
||||
int landmarkIndex = 0;
|
||||
for (NormalizedLandmark landmark : poseLandmarks.getLandmarkList()) {
|
||||
poseLandmarkStr +=
|
||||
"\tLandmark ["
|
||||
+ landmarkIndex
|
||||
+ "]: ("
|
||||
+ landmark.getX()
|
||||
+ ", "
|
||||
+ landmark.getY()
|
||||
+ ", "
|
||||
+ landmark.getZ()
|
||||
+ ")\n";
|
||||
++landmarkIndex;
|
||||
}
|
||||
return poseLandmarkStr;
|
||||
}
|
||||
}
|
|
@ -142,26 +142,13 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Maps detection label IDs to the corresponding label text ("Face"). The label
|
||||
# map is provided in the label_map_path option.
|
||||
node {
|
||||
calculator: "DetectionLabelIdToTextCalculator"
|
||||
input_stream: "filtered_detections"
|
||||
output_stream: "labeled_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
||||
# letterboxed image (after image transformation with the FIT scale mode) to the
|
||||
# corresponding locations on the same image with the letterbox removed (the
|
||||
# input image to the graph before image transformation).
|
||||
node {
|
||||
calculator: "DetectionLetterboxRemovalCalculator"
|
||||
input_stream: "DETECTIONS:labeled_detections"
|
||||
input_stream: "DETECTIONS:filtered_detections"
|
||||
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
output_stream: "DETECTIONS:output_detections"
|
||||
}
|
||||
|
|
|
@ -33,6 +33,10 @@ constexpr char kDetections[] = "DETECTIONS";
|
|||
constexpr char kDetectedBorders[] = "BORDERS";
|
||||
constexpr char kCropRect[] = "CROP_RECT";
|
||||
constexpr char kFirstCropRect[] = "FIRST_CROP_RECT";
|
||||
// Can be used to control whether an animated zoom should actually performed
|
||||
// (configured through option us_to_first_rect). If provided, a non-zero integer
|
||||
// will allow the animated zoom to be used when the first detections arrive.
|
||||
constexpr char kAnimateZoom[] = "ANIMATE_ZOOM";
|
||||
// Field-of-view (degrees) of the camera's x-axis (width).
|
||||
// TODO: Parameterize FOV based on camera specs.
|
||||
constexpr float kFieldOfView = 60;
|
||||
|
@ -76,10 +80,10 @@ class ContentZoomingCalculator : public CalculatorBase {
|
|||
absl::Status InitializeState(int frame_width, int frame_height);
|
||||
// Adjusts state to work with an updated frame size.
|
||||
absl::Status UpdateForResolutionChange(int frame_width, int frame_height);
|
||||
// Returns true if we are zooming to the initial rect.
|
||||
bool IsZoomingToInitialRect(const Timestamp& timestamp) const;
|
||||
// Builds the output rectangle when zooming to the initial rect.
|
||||
absl::StatusOr<mediapipe::Rect> GetInitialZoomingRect(
|
||||
// Returns true if we are animating to the first rect.
|
||||
bool IsAnimatingToFirstRect(const Timestamp& timestamp) const;
|
||||
// Builds the output rectangle when animating to the first rect.
|
||||
absl::StatusOr<mediapipe::Rect> GetAnimationRect(
|
||||
int frame_width, int frame_height, const Timestamp& timestamp) const;
|
||||
// Converts bounds to tilt offset, pan offset and height.
|
||||
absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
|
||||
|
@ -97,7 +101,10 @@ class ContentZoomingCalculator : public CalculatorBase {
|
|||
std::unique_ptr<KinematicPathSolver> path_solver_tilt_;
|
||||
// Are parameters initialized.
|
||||
bool initialized_;
|
||||
// Stores the time of the first crop rectangle.
|
||||
// Stores the time of the first crop rectangle. This is used to control
|
||||
// animating to it. Until a first crop rectangle was computed, it has
|
||||
// the value Timestamp::Unset(). If animating is not requested, it receives
|
||||
// the value Timestamp::Done() instead of the time.
|
||||
Timestamp first_rect_timestamp_;
|
||||
// Stores the first crop rectangle.
|
||||
mediapipe::NormalizedRect first_rect_;
|
||||
|
@ -135,6 +142,9 @@ absl::Status ContentZoomingCalculator::GetContract(
|
|||
if (cc->Inputs().HasTag(kDetections)) {
|
||||
cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kAnimateZoom)) {
|
||||
cc->Inputs().Tag(kAnimateZoom).Set<bool>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kDetectedBorders)) {
|
||||
cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
|
||||
}
|
||||
|
@ -419,10 +429,11 @@ absl::Status ContentZoomingCalculator::UpdateForResolutionChange(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool ContentZoomingCalculator::IsZoomingToInitialRect(
|
||||
bool ContentZoomingCalculator::IsAnimatingToFirstRect(
|
||||
const Timestamp& timestamp) const {
|
||||
if (options_.us_to_first_rect() == 0 ||
|
||||
first_rect_timestamp_ == Timestamp::Unset()) {
|
||||
first_rect_timestamp_ == Timestamp::Unset() ||
|
||||
first_rect_timestamp_ == Timestamp::Done()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -443,10 +454,10 @@ double easeInOutQuad(double t) {
|
|||
double lerp(double a, double b, double i) { return a * (1 - i) + b * i; }
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetInitialZoomingRect(
|
||||
absl::StatusOr<mediapipe::Rect> ContentZoomingCalculator::GetAnimationRect(
|
||||
int frame_width, int frame_height, const Timestamp& timestamp) const {
|
||||
RET_CHECK(IsZoomingToInitialRect(timestamp))
|
||||
<< "Must only be called if zooming to initial rect.";
|
||||
RET_CHECK(IsAnimatingToFirstRect(timestamp))
|
||||
<< "Must only be called if animating to first rect.";
|
||||
|
||||
const int64 delta_us = (timestamp - first_rect_timestamp_).Value();
|
||||
const int64 delay = options_.us_to_first_rect_delay();
|
||||
|
@ -538,15 +549,20 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
}
|
||||
}
|
||||
|
||||
bool zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
||||
const bool may_start_animation = (options_.us_to_first_rect() != 0) &&
|
||||
(!cc->Inputs().HasTag(kAnimateZoom) ||
|
||||
cc->Inputs().Tag(kAnimateZoom).Get<bool>());
|
||||
bool is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
|
||||
|
||||
int offset_y, height, offset_x;
|
||||
if (zooming_to_initial_rect) {
|
||||
// If we are zooming to the first rect, ignore any new incoming detections.
|
||||
height = last_measured_height_;
|
||||
offset_x = last_measured_x_offset_;
|
||||
offset_y = last_measured_y_offset_;
|
||||
} else if (only_required_found) {
|
||||
if (!is_animating && options_.start_zoomed_out() && !may_start_animation &&
|
||||
first_rect_timestamp_ == Timestamp::Unset()) {
|
||||
// If we should start zoomed out and won't be doing an animation,
|
||||
// initialize the path solvers using the full frame, ignoring detections.
|
||||
height = max_frame_value_ * frame_height_;
|
||||
offset_x = (target_aspect_ * height) / 2;
|
||||
offset_y = frame_height_ / 2;
|
||||
} else if (!is_animating && only_required_found) {
|
||||
// Convert bounds to tilt/zoom and in pixel coordinates.
|
||||
MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
|
||||
&offset_x, &height));
|
||||
|
@ -555,9 +571,9 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
last_measured_height_ = height;
|
||||
last_measured_x_offset_ = offset_x;
|
||||
last_measured_y_offset_ = offset_y;
|
||||
} else if (cc->InputTimestamp().Microseconds() -
|
||||
last_only_required_detection_ >=
|
||||
options_.us_before_zoomout()) {
|
||||
} else if (!is_animating && cc->InputTimestamp().Microseconds() -
|
||||
last_only_required_detection_ >=
|
||||
options_.us_before_zoomout()) {
|
||||
// No only_require detections found within salient regions packets
|
||||
// arriving since us_before_zoomout duration.
|
||||
height = max_frame_value_ * frame_height_ +
|
||||
|
@ -566,7 +582,8 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
offset_x = (target_aspect_ * height) / 2;
|
||||
offset_y = frame_height_ / 2;
|
||||
} else {
|
||||
// No only detection found but using last detection due to
|
||||
// Either animating to the first rectangle, or
|
||||
// no only detection found but using last detection due to
|
||||
// duration_before_zoomout_us setting.
|
||||
height = last_measured_height_;
|
||||
offset_x = last_measured_x_offset_;
|
||||
|
@ -642,24 +659,28 @@ absl::Status ContentZoomingCalculator::Process(
|
|||
.AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (first_rect_timestamp_ == Timestamp::Unset() &&
|
||||
options_.us_to_first_rect() != 0) {
|
||||
first_rect_timestamp_ = cc->InputTimestamp();
|
||||
// Record the first crop rectangle
|
||||
if (first_rect_timestamp_ == Timestamp::Unset()) {
|
||||
first_rect_.set_x_center(path_offset_x / static_cast<float>(frame_width_));
|
||||
first_rect_.set_width(path_height * target_aspect_ /
|
||||
static_cast<float>(frame_width_));
|
||||
first_rect_.set_y_center(path_offset_y / static_cast<float>(frame_height_));
|
||||
first_rect_.set_height(path_height / static_cast<float>(frame_height_));
|
||||
// After setting the first rectangle, check whether we should zoom to it.
|
||||
zooming_to_initial_rect = IsZoomingToInitialRect(cc->InputTimestamp());
|
||||
|
||||
// Record the time to serve as departure point for the animation.
|
||||
// If we are not allowed to start the animation, set Timestamp::Done.
|
||||
first_rect_timestamp_ =
|
||||
may_start_animation ? cc->InputTimestamp() : Timestamp::Done();
|
||||
// After setting the first rectangle, check whether we should animate to it.
|
||||
is_animating = IsAnimatingToFirstRect(cc->InputTimestamp());
|
||||
}
|
||||
|
||||
// Transmit downstream to glcroppingcalculator.
|
||||
if (cc->Outputs().HasTag(kCropRect)) {
|
||||
std::unique_ptr<mediapipe::Rect> gpu_rect;
|
||||
if (zooming_to_initial_rect) {
|
||||
auto rect = GetInitialZoomingRect(frame_width, frame_height,
|
||||
cc->InputTimestamp());
|
||||
if (is_animating) {
|
||||
auto rect =
|
||||
GetAnimationRect(frame_width, frame_height, cc->InputTimestamp());
|
||||
MP_RETURN_IF_ERROR(rect.status());
|
||||
gpu_rect = absl::make_unique<mediapipe::Rect>(*rect);
|
||||
} else {
|
||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.autoflip;
|
|||
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// NextTag: 17
|
||||
// NextTag: 18
|
||||
message ContentZoomingCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ContentZoomingCalculatorOptions ext = 313091992;
|
||||
|
@ -58,9 +58,15 @@ message ContentZoomingCalculatorOptions {
|
|||
// Whether to keep state between frames or to compute the final crop rect.
|
||||
optional bool is_stateless = 14 [default = false];
|
||||
|
||||
// Duration (in MicroSeconds) for moving to the first crop rect.
|
||||
// If true, on the first packet start with the camera zoomed out and then zoom
|
||||
// in on the subject. If false, the camera will start zoomed in on the
|
||||
// subject.
|
||||
optional bool start_zoomed_out = 17 [default = false];
|
||||
|
||||
// Duration (in MicroSeconds) for animating to the first crop rect.
|
||||
// Note that if set, takes precedence over start_zoomed_out.
|
||||
optional int64 us_to_first_rect = 15 [default = 0];
|
||||
// Duration (in MicroSeconds) to delay moving to the first crop rect.
|
||||
// Duration (in MicroSeconds) to delay animating to the first crop rect.
|
||||
// Used only if us_to_first_rect is set and is interpreted as part of the
|
||||
// us_to_first_rect time budget.
|
||||
optional int64 us_to_first_rect_delay = 16 [default = 0];
|
||||
|
|
|
@ -127,6 +127,29 @@ const char kConfigD[] = R"(
|
|||
}
|
||||
)";
|
||||
|
||||
const char kConfigE[] = R"(
|
||||
calculator: "ContentZoomingCalculator"
|
||||
input_stream: "VIDEO_SIZE:size"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "ANIMATE_ZOOM:animate_zoom"
|
||||
output_stream: "CROP_RECT:rect"
|
||||
output_stream: "FIRST_CROP_RECT:first_rect"
|
||||
options: {
|
||||
[mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
|
||||
max_zoom_value_deg: 0
|
||||
kinematic_options_zoom {
|
||||
min_motion_to_reframe: 1.2
|
||||
}
|
||||
kinematic_options_tilt {
|
||||
min_motion_to_reframe: 1.2
|
||||
}
|
||||
kinematic_options_pan {
|
||||
min_motion_to_reframe: 1.2
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
||||
int top_border, int bottom_border) {
|
||||
ASSERT_EQ(2, static_features.border().size());
|
||||
|
@ -145,9 +168,14 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
|
|||
EXPECT_EQ(Border::BOTTOM, part.relative_position());
|
||||
}
|
||||
|
||||
struct AddDetectionFlags {
|
||||
std::optional<bool> animated_zoom;
|
||||
};
|
||||
|
||||
void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
||||
const int width, const int height,
|
||||
CalculatorRunner* runner) {
|
||||
CalculatorRunner* runner,
|
||||
const AddDetectionFlags& flags = {}) {
|
||||
auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
|
||||
if (position.width > 0 && position.height > 0) {
|
||||
mediapipe::Detection detection;
|
||||
|
@ -175,6 +203,14 @@ void AddDetectionFrameSize(const cv::Rect_<float>& position, const int64 time,
|
|||
runner->MutableInputs()
|
||||
->Tag("VIDEO_SIZE")
|
||||
.packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
|
||||
|
||||
if (flags.animated_zoom.has_value()) {
|
||||
runner->MutableInputs()
|
||||
->Tag("ANIMATE_ZOOM")
|
||||
.packets.push_back(
|
||||
mediapipe::MakePacket<bool>(flags.animated_zoom.value())
|
||||
.At(Timestamp(time)));
|
||||
}
|
||||
}
|
||||
|
||||
void AddDetection(const cv::Rect_<float>& position, const int64 time,
|
||||
|
@ -703,7 +739,33 @@ TEST(ContentZoomingCalculatorTest, MaxZoomOutValue) {
|
|||
CheckCropRect(500, 500, 1000, 1000, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_start_zoomed_out(true);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||
runner.get());
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||
runner.get());
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 880, 880, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 760, 760, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 655, 655, 3,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, AnimateToFirstRect) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
|
@ -733,6 +795,65 @@ TEST(ContentZoomingCalculatorTest, StartZoomedOut) {
|
|||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, CanControlAnimation) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_start_zoomed_out(true);
|
||||
options->set_us_to_first_rect(1000000);
|
||||
options->set_us_to_first_rect_delay(500000);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
// Request the animation for the first frame.
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = true});
|
||||
// We now stop requesting animated zoom and expect the already started
|
||||
// animation run to completion. This tests that the zoom in continues in the
|
||||
// call when it was started in the Meet greenroom.
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1000000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 1500000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 1000, 1000, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 470, 470, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 222, 222, 3,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 222, 222, 4,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, DoesNotAnimateIfDisabledViaInput) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigE);
|
||||
auto* options = config.mutable_options()->MutableExtension(
|
||||
ContentZoomingCalculatorOptions::ext);
|
||||
options->set_start_zoomed_out(true);
|
||||
options->set_us_to_first_rect(1000000);
|
||||
options->set_us_to_first_rect_delay(500000);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
// Disable the animation already for the first frame.
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 0, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 400000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
AddDetectionFrameSize(cv::Rect_<float>(.4, .4, .2, .2), 800000, 1000, 1000,
|
||||
runner.get(), {.animated_zoom = false});
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
CheckCropRect(500, 500, 1000, 1000, 0,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 880, 880, 1,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
CheckCropRect(500, 500, 760, 760, 2,
|
||||
runner->Outputs().Tag("CROP_RECT").packets);
|
||||
}
|
||||
|
||||
TEST(ContentZoomingCalculatorTest, ProvidesZeroSizeFirstRectWithoutDetections) {
|
||||
auto config = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(kConfigD);
|
||||
auto runner = ::absl::make_unique<CalculatorRunner>(config);
|
||||
|
|
|
@ -47,4 +47,10 @@ message FaceBoxAdjusterCalculatorOptions {
|
|||
// and height respectively.
|
||||
optional float ipd_face_box_width_ratio = 6 [default = 0.5566];
|
||||
optional float ipd_face_box_height_ratio = 7 [default = 0.3131];
|
||||
|
||||
// The max look up angle before considering the eye distance unstable.
|
||||
optional float max_head_tilt_angle_deg = 8 [default = 12.0];
|
||||
// The max amount of time to use an old eye distance when the face look angle
|
||||
// is unstable.
|
||||
optional int32 max_facesize_history_us = 9 [default = 8000000];
|
||||
}
|
||||
|
|
|
@ -345,8 +345,7 @@ TEST(SceneCroppingCalculatorTest, ChecksPriorFrameBufferSize) {
|
|||
TEST(SceneCroppingCalculatorTest, ChecksDebugConfigWithoutCroppedFrame) {
|
||||
const CalculatorGraphConfig::Node config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(absl::Substitute(
|
||||
kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight,
|
||||
kTargetSizeType, 0, kPriorFrameBufferSize));
|
||||
kDebugConfigNoCroppedFrame, kTargetWidth, kTargetHeight));
|
||||
auto runner = absl::make_unique<CalculatorRunner>(config);
|
||||
const auto status = runner->Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
|
|
|
@ -220,7 +220,7 @@ absl::Status KinematicPathSolver::GetTargetPosition(int* target_position) {
|
|||
|
||||
absl::Status KinematicPathSolver::UpdatePixelsPerDegree(
|
||||
const float pixels_per_degree) {
|
||||
RET_CHECK_GT(pixels_per_degree_, 0)
|
||||
RET_CHECK_GT(pixels_per_degree, 0)
|
||||
<< "pixels_per_degree must be larger than 0.";
|
||||
pixels_per_degree_ = pixels_per_degree;
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -38,7 +38,7 @@ node {
|
|||
output_stream: "TENSORS:detection_tensors"
|
||||
options: {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/models/face_detection_back.tflite"
|
||||
model_path: "mediapipe/modules/face_detection/face_detection_back.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -111,26 +111,13 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Maps detection label IDs to the corresponding label text ("Face"). The label
|
||||
# map is provided in the label_map_path option.
|
||||
node {
|
||||
calculator: "DetectionLabelIdToTextCalculator"
|
||||
input_stream: "filtered_detections"
|
||||
output_stream: "labeled_detections"
|
||||
options: {
|
||||
[mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
|
||||
label_map_path: "mediapipe/models/face_detection_back_labelmap.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
|
||||
# letterboxed image (after image transformation with the FIT scale mode) to the
|
||||
# corresponding locations on the same image with the letterbox removed (the
|
||||
# input image to the graph before image transformation).
|
||||
node {
|
||||
calculator: "DetectionLetterboxRemovalCalculator"
|
||||
input_stream: "DETECTIONS:labeled_detections"
|
||||
input_stream: "DETECTIONS:filtered_detections"
|
||||
input_stream: "LETTERBOX_PADDING:letterbox_padding"
|
||||
output_stream: "DETECTIONS:output_detections"
|
||||
}
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright 2020 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
|
||||
|
||||
cc_binary(
|
||||
name = "upper_body_pose_tracking_cpu",
|
||||
deps = [
|
||||
"//mediapipe/examples/desktop:demo_run_graph_main",
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_cpu_deps",
|
||||
],
|
||||
)
|
||||
|
||||
# Linux only
|
||||
cc_binary(
|
||||
name = "upper_body_pose_tracking_gpu",
|
||||
deps = [
|
||||
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
||||
],
|
||||
)
|
|
@ -61,8 +61,7 @@ objc_library(
|
|||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
|
|
|
@ -63,7 +63,7 @@ objc_library(
|
|||
data = [
|
||||
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full_body.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright 2020 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||
"ios_application",
|
||||
)
|
||||
load(
|
||||
"//mediapipe/examples/ios:bundle_id.bzl",
|
||||
"BUNDLE_ID_PREFIX",
|
||||
"example_provisioning",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
|
||||
alias(
|
||||
name = "upperbodyposetrackinggpu",
|
||||
actual = "UpperBodyPoseTrackingGpuApp",
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "UpperBodyPoseTrackingGpuApp",
|
||||
app_icons = ["//mediapipe/examples/ios/common:AppIcon"],
|
||||
bundle_id = BUNDLE_ID_PREFIX + ".UpperBodyPoseTrackingGpu",
|
||||
families = [
|
||||
"iphone",
|
||||
"ipad",
|
||||
],
|
||||
infoplists = [
|
||||
"//mediapipe/examples/ios/common:Info.plist",
|
||||
"Info.plist",
|
||||
],
|
||||
minimum_os_version = MIN_IOS_VERSION,
|
||||
provisioning_profile = example_provisioning(),
|
||||
deps = [
|
||||
":UpperBodyPoseTrackingGpuAppLibrary",
|
||||
"@ios_opencv//:OpencvFramework",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "UpperBodyPoseTrackingGpuAppLibrary",
|
||||
srcs = [
|
||||
"UpperBodyPoseTrackingViewController.mm",
|
||||
],
|
||||
hdrs = [
|
||||
"UpperBodyPoseTrackingViewController.h",
|
||||
],
|
||||
copts = ["-std=c++17"],
|
||||
data = [
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
"//mediapipe/modules/pose_landmark:pose_landmark_upper_body.tflite",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||
] + select({
|
||||
"//mediapipe:ios_i386": [],
|
||||
"//mediapipe:ios_x86_64": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/graphs/pose_tracking:upper_body_pose_tracking_gpu_deps",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
}),
|
||||
)
|
|
@ -1,16 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CameraPosition</key>
|
||||
<string>back</string>
|
||||
<key>MainViewController</key>
|
||||
<string>UpperBodyPoseTrackingViewController</string>
|
||||
<key>GraphOutputStream</key>
|
||||
<string>output_video</string>
|
||||
<key>GraphInputStream</key>
|
||||
<string>input_video</string>
|
||||
<key>GraphName</key>
|
||||
<string>upper_body_pose_tracking_gpu</string>
|
||||
</dict>
|
||||
</plist>
|
|
@ -1,53 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "UpperBodyPoseTrackingViewController.h"
|
||||
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
|
||||
static const char* kLandmarksOutputStream = "pose_landmarks";
|
||||
|
||||
@implementation UpperBodyPoseTrackingViewController
|
||||
|
||||
#pragma mark - UIViewController methods
|
||||
|
||||
- (void)viewDidLoad {
|
||||
[super viewDidLoad];
|
||||
|
||||
[self.mediapipeGraph addFrameOutputStream:kLandmarksOutputStream
|
||||
outputPacketType:MPPPacketTypeRaw];
|
||||
}
|
||||
|
||||
#pragma mark - MPPGraphDelegate methods
|
||||
|
||||
// Receives a raw packet from the MediaPipe graph. Invoked on a MediaPipe worker thread.
|
||||
- (void)mediapipeGraph:(MPPGraph*)graph
|
||||
didOutputPacket:(const ::mediapipe::Packet&)packet
|
||||
fromStream:(const std::string&)streamName {
|
||||
if (streamName == kLandmarksOutputStream) {
|
||||
if (packet.IsEmpty()) {
|
||||
NSLog(@"[TS:%lld] No pose landmarks", packet.Timestamp().Value());
|
||||
return;
|
||||
}
|
||||
const auto& landmarks = packet.Get<::mediapipe::NormalizedLandmarkList>();
|
||||
NSLog(@"[TS:%lld] Number of pose landmarks: %d", packet.Timestamp().Value(),
|
||||
landmarks.landmark_size());
|
||||
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
||||
NSLog(@"\tLandmark[%d]: (%f, %f, %f)", i, landmarks.landmark(i).x(),
|
||||
landmarks.landmark(i).y(), landmarks.landmark(i).z());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
|
@ -15,6 +15,7 @@
|
|||
#
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -27,10 +28,28 @@ package_group(
|
|||
],
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"transitive_protos.bzl",
|
||||
"encode_binary_proto.bzl",
|
||||
])
|
||||
bzl_library(
|
||||
name = "transitive_protos_bzl",
|
||||
srcs = [
|
||||
"transitive_protos.bzl",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
)
|
||||
|
||||
bzl_library(
|
||||
name = "encode_binary_proto_bzl",
|
||||
srcs = [
|
||||
"encode_binary_proto.bzl",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "encode_binary_proto",
|
||||
actual = ":encode_binary_proto_bzl",
|
||||
deprecation = "Use encode_binary_proto_bzl",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "calculator_proto",
|
||||
|
|
|
@ -1,15 +1,8 @@
|
|||
package(
|
||||
default_visibility = [":preview_users"],
|
||||
default_visibility = ["//visibility:public"],
|
||||
features = ["-use_header_modules"],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "preview_users",
|
||||
packages = [
|
||||
"//mediapipe/...",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -422,6 +422,9 @@ message CalculatorGraphConfig {
|
|||
// the graph config.
|
||||
string type = 20;
|
||||
|
||||
// Can be used for annotating a graph.
|
||||
// The types and default values for graph options, in proto2 syntax.
|
||||
MediaPipeOptions options = 1001;
|
||||
|
||||
// The types and default values for graph options, in proto3 syntax.
|
||||
repeated google.protobuf.Any graph_options = 1002;
|
||||
}
|
||||
|
|
|
@ -411,7 +411,8 @@ absl::Status CalculatorGraph::Initialize(
|
|||
|
||||
absl::Status CalculatorGraph::ObserveOutputStream(
|
||||
const std::string& stream_name,
|
||||
std::function<absl::Status(const Packet&)> packet_callback) {
|
||||
std::function<absl::Status(const Packet&)> packet_callback,
|
||||
bool observe_timestamp_bounds) {
|
||||
RET_CHECK(initialized_).SetNoLogging()
|
||||
<< "CalculatorGraph is not initialized.";
|
||||
// TODO Allow output observers to be attached by graph level
|
||||
|
@ -425,7 +426,7 @@ absl::Status CalculatorGraph::ObserveOutputStream(
|
|||
auto observer = absl::make_unique<internal::OutputStreamObserver>();
|
||||
MP_RETURN_IF_ERROR(observer->Initialize(
|
||||
stream_name, &any_packet_type_, std::move(packet_callback),
|
||||
&output_stream_managers_[output_stream_index]));
|
||||
&output_stream_managers_[output_stream_index], observe_timestamp_bounds));
|
||||
graph_output_streams_.push_back(std::move(observer));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -157,7 +157,8 @@ class CalculatorGraph {
|
|||
// TODO: Rename to AddOutputStreamCallback.
|
||||
absl::Status ObserveOutputStream(
|
||||
const std::string& stream_name,
|
||||
std::function<absl::Status(const Packet&)> packet_callback);
|
||||
std::function<absl::Status(const Packet&)> packet_callback,
|
||||
bool observe_timestamp_bounds = false);
|
||||
|
||||
// Adds an OutputStreamPoller for a stream. This provides a synchronous,
|
||||
// polling API for accessing a stream's output. Should only be called before
|
||||
|
|
|
@ -1518,5 +1518,72 @@ TEST(CalculatorGraphBoundsTest, OffsetAndBound) {
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// A Calculator that sends empty output stream packets.
|
||||
class EmptyPacketCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (cc->InputTimestamp().Value() % 2 == 0) {
|
||||
cc->Outputs().Index(0).AddPacket(Packet().At(cc->InputTimestamp()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(EmptyPacketCalculator);
|
||||
|
||||
// This test shows that an output timestamp bound can be specified by outputing
|
||||
// an empty packet with a settled timestamp.
|
||||
TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
|
||||
// OffsetAndBoundCalculator runs on parallel threads and sends ts
|
||||
// occasionally.
|
||||
std::string config_str = R"(
|
||||
input_stream: "input_0"
|
||||
node {
|
||||
calculator: "EmptyPacketCalculator"
|
||||
input_stream: "input_0"
|
||||
output_stream: "empty_0"
|
||||
}
|
||||
node {
|
||||
calculator: "ProcessBoundToPacketCalculator"
|
||||
input_stream: "empty_0"
|
||||
output_stream: "output_0"
|
||||
}
|
||||
)";
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||
CalculatorGraph graph;
|
||||
std::vector<Packet> output_0_packets;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
|
||||
output_0_packets.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Send in packets.
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
const int ts = 10 + i * 10;
|
||||
Packet p = MakePacket<int>(i).At(Timestamp(ts));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
}
|
||||
|
||||
// 9 empty packets are converted to bounds and then to packets.
|
||||
EXPECT_EQ(output_0_packets.size(), 9);
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
|
||||
}
|
||||
|
||||
// Shutdown the graph.
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -16,11 +16,20 @@
|
|||
# The dependencies of mediapipe.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
bzl_library(
|
||||
name = "expand_template_bzl",
|
||||
srcs = [
|
||||
"expand_template.bzl",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "proto_descriptor_proto",
|
||||
srcs = ["proto_descriptor.proto"],
|
||||
|
|
|
@ -295,6 +295,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"//mediapipe/framework:port",
|
||||
"//mediapipe/framework:type_map",
|
||||
"//mediapipe/framework/port:logging",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// TODO Refactor common code from GpuBufferToImageFrameCalculator
|
||||
|
@ -67,8 +69,7 @@ bool Image::ConvertToGpu() const {
|
|||
#else
|
||||
if (use_gpu_) return true; // Already on GPU.
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
auto packet = MakePacket<ImageFrame>(std::move(*image_frame_));
|
||||
image_frame_ = nullptr;
|
||||
auto packet = PointToForeign<ImageFrame>(image_frame_.get());
|
||||
CFHolder<CVPixelBufferRef> buffer;
|
||||
auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer);
|
||||
CHECK_OK(status);
|
||||
|
@ -94,4 +95,7 @@ bool Image::ConvertToGpu() const {
|
|||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_TYPE(mediapipe::Image, "::mediapipe::Image", nullptr,
|
||||
nullptr);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -72,8 +72,8 @@ class Image {
|
|||
|
||||
// Creates an Image representing the same image content as the ImageFrame
|
||||
// the input shared pointer points to, and retaining shared ownership.
|
||||
explicit Image(ImageFrameSharedPtr frame_buffer)
|
||||
: image_frame_(std::move(frame_buffer)) {
|
||||
explicit Image(ImageFrameSharedPtr image_frame)
|
||||
: image_frame_(std::move(image_frame)) {
|
||||
use_gpu_ = false;
|
||||
pixel_mutex_ = std::make_shared<absl::Mutex>();
|
||||
}
|
||||
|
|
|
@ -30,6 +30,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
// Zero and negative values are not checked here.
|
||||
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
|
||||
|
||||
int BhwcBatchFromShape(const Tensor::Shape& shape) {
|
||||
LOG_IF(FATAL, shape.dims.empty())
|
||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
||||
|
@ -237,6 +240,12 @@ void Tensor::AllocateOpenGlTexture2d() const {
|
|||
glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_,
|
||||
texture_height_);
|
||||
} else {
|
||||
// GLES2.0 supports only clamp addressing mode for NPOT textures.
|
||||
// If any of dimensions is NPOT then both addressing modes are clamp.
|
||||
if (!IsPowerOfTwo(texture_width_) || !IsPowerOfTwo(texture_height_)) {
|
||||
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
|
||||
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
|
||||
}
|
||||
// We assume all contexts will have the same extensions, so we only check
|
||||
// once for OES_texture_float extension, to save time.
|
||||
static bool has_oes_extension =
|
||||
|
|
|
@ -14,13 +14,16 @@
|
|||
|
||||
#include "mediapipe/framework/graph_output_stream.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace internal {
|
||||
|
||||
absl::Status GraphOutputStream::Initialize(
|
||||
const std::string& stream_name, const PacketType* packet_type,
|
||||
OutputStreamManager* output_stream_manager) {
|
||||
OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
|
||||
RET_CHECK(output_stream_manager);
|
||||
|
||||
// Initializes input_stream_handler_ with one input stream as the observer.
|
||||
|
@ -31,6 +34,7 @@ absl::Status GraphOutputStream::Initialize(
|
|||
input_stream_handler_ = absl::make_unique<GraphOutputStreamHandler>(
|
||||
tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(),
|
||||
/*calculator_run_in_parallel=*/false);
|
||||
input_stream_handler_->SetProcessTimestampBounds(observe_timestamp_bounds);
|
||||
const CollectionItemId& id = tag_map->BeginId();
|
||||
input_stream_ = absl::make_unique<InputStreamManager>();
|
||||
MP_RETURN_IF_ERROR(
|
||||
|
@ -52,20 +56,58 @@ void GraphOutputStream::PrepareForRun(
|
|||
absl::Status OutputStreamObserver::Initialize(
|
||||
const std::string& stream_name, const PacketType* packet_type,
|
||||
std::function<absl::Status(const Packet&)> packet_callback,
|
||||
OutputStreamManager* output_stream_manager) {
|
||||
OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
|
||||
RET_CHECK(output_stream_manager);
|
||||
|
||||
packet_callback_ = std::move(packet_callback);
|
||||
observe_timestamp_bounds_ = observe_timestamp_bounds;
|
||||
return GraphOutputStream::Initialize(stream_name, packet_type,
|
||||
output_stream_manager);
|
||||
output_stream_manager,
|
||||
observe_timestamp_bounds);
|
||||
}
|
||||
|
||||
absl::Status OutputStreamObserver::Notify() {
|
||||
// Lets one thread perform packets notification as much as possible.
|
||||
// Other threads should quit if a thread is already performing notification.
|
||||
{
|
||||
absl::MutexLock l(&mutex_);
|
||||
|
||||
if (notifying_ == false) {
|
||||
notifying_ = true;
|
||||
} else {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
while (true) {
|
||||
bool empty;
|
||||
Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty);
|
||||
if (empty) {
|
||||
break;
|
||||
// Emits an empty packet at timestamp_bound.PreviousAllowedInStream().
|
||||
if (observe_timestamp_bounds_ && min_timestamp < Timestamp::Done()) {
|
||||
Timestamp settled = (min_timestamp == Timestamp::PostStream()
|
||||
? Timestamp::PostStream()
|
||||
: min_timestamp.PreviousAllowedInStream());
|
||||
if (last_processed_ts_ < settled) {
|
||||
MP_RETURN_IF_ERROR(packet_callback_(Packet().At(settled)));
|
||||
last_processed_ts_ = settled;
|
||||
}
|
||||
}
|
||||
// Last check to make sure that the min timestamp or bound doesn't change.
|
||||
// If so, flips notifying_ to false to allow any other threads to perform
|
||||
// notification when new packets/timestamp bounds arrive. Otherwise, in
|
||||
// case of the min timestamp or bound getting updated, jumps to the
|
||||
// beginning of the notification loop for a new iteration.
|
||||
{
|
||||
absl::MutexLock l(&mutex_);
|
||||
Timestamp new_min_timestamp =
|
||||
input_stream_->MinTimestampOrBound(&empty);
|
||||
if (new_min_timestamp == min_timestamp) {
|
||||
notifying_ = false;
|
||||
break;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
int num_packets_dropped = 0;
|
||||
bool stream_is_done = false;
|
||||
|
@ -75,6 +117,7 @@ absl::Status OutputStreamObserver::Notify() {
|
|||
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
|
||||
num_packets_dropped, input_stream_->Name());
|
||||
MP_RETURN_IF_ERROR(packet_callback_(packet));
|
||||
last_processed_ts_ = min_timestamp;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -52,7 +52,8 @@ class GraphOutputStream {
|
|||
// is not transferred to the graph output stream object.
|
||||
absl::Status Initialize(const std::string& stream_name,
|
||||
const PacketType* packet_type,
|
||||
OutputStreamManager* output_stream_manager);
|
||||
OutputStreamManager* output_stream_manager,
|
||||
bool observe_timestamp_bounds = false);
|
||||
|
||||
// Installs callbacks into its GraphOutputStreamHandler.
|
||||
virtual void PrepareForRun(std::function<void()> notification_callback,
|
||||
|
@ -99,6 +100,10 @@ class GraphOutputStream {
|
|||
}
|
||||
};
|
||||
|
||||
bool observe_timestamp_bounds_;
|
||||
absl::Mutex mutex_;
|
||||
bool notifying_ ABSL_GUARDED_BY(mutex_) = false;
|
||||
Timestamp last_processed_ts_ = Timestamp::Unstarted();
|
||||
std::unique_ptr<InputStreamHandler> input_stream_handler_;
|
||||
std::unique_ptr<InputStreamManager> input_stream_;
|
||||
};
|
||||
|
@ -112,7 +117,8 @@ class OutputStreamObserver : public GraphOutputStream {
|
|||
absl::Status Initialize(
|
||||
const std::string& stream_name, const PacketType* packet_type,
|
||||
std::function<absl::Status(const Packet&)> packet_callback,
|
||||
OutputStreamManager* output_stream_manager);
|
||||
OutputStreamManager* output_stream_manager,
|
||||
bool observe_timestamp_bounds = false);
|
||||
|
||||
// Notifies the observer of new packets emitted by the observed
|
||||
// output stream.
|
||||
|
@ -128,6 +134,7 @@ class OutputStreamObserver : public GraphOutputStream {
|
|||
|
||||
// OutputStreamPollerImpl that returns packets to the caller via
|
||||
// Next()/NextBatch().
|
||||
// TODO: Support observe_timestamp_bounds.
|
||||
class OutputStreamPollerImpl : public GraphOutputStream {
|
||||
public:
|
||||
virtual ~OutputStreamPollerImpl() {}
|
||||
|
|
|
@ -20,6 +20,9 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe;
|
||||
|
||||
option java_package = "com.google.mediapipe.proto";
|
||||
option java_outer_classname = "MediaPipeOptionsProto";
|
||||
|
||||
// Options used by a MediaPipe object.
|
||||
message MediaPipeOptions {
|
||||
extensions 20000 to max;
|
||||
|
|
|
@ -101,8 +101,8 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) {
|
|||
}
|
||||
|
||||
if (packet.IsEmpty()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Empty packet sent to stream \"" << Name() << "\".";
|
||||
SetNextTimestampBound(packet.Timestamp().NextAllowedInStream());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const Timestamp timestamp = packet.Timestamp();
|
||||
|
|
|
@ -20,6 +20,7 @@ load(
|
|||
"mediapipe_binary_graph",
|
||||
)
|
||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -29,6 +30,30 @@ exports_files([
|
|||
"simple_subgraph_template.cc",
|
||||
])
|
||||
|
||||
bzl_library(
|
||||
name = "mediapipe_graph_bzl",
|
||||
srcs = [
|
||||
"mediapipe_graph.bzl",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":build_defs_bzl",
|
||||
"//mediapipe/framework:encode_binary_proto",
|
||||
"//mediapipe/framework:transitive_protos_bzl",
|
||||
"//mediapipe/framework/deps:expand_template_bzl",
|
||||
],
|
||||
)
|
||||
|
||||
bzl_library(
|
||||
name = "build_defs_bzl",
|
||||
srcs = [
|
||||
"build_defs.bzl",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe/framework:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_to_binary_graph",
|
||||
srcs = ["text_to_binary_graph.cc"],
|
||||
|
@ -744,5 +769,7 @@ cc_test(
|
|||
|
||||
exports_files(
|
||||
["build_defs.bzl"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
@ -168,8 +169,19 @@ void AddMultiStreamCallback(
|
|||
std::function<void(const std::vector<Packet>&)> callback,
|
||||
CalculatorGraphConfig* config,
|
||||
std::pair<std::string, Packet>* side_packet) {
|
||||
std::map<std::string, Packet> side_packets;
|
||||
AddMultiStreamCallback(streams, callback, config, &side_packets,
|
||||
/*observe_timestamp_bounds=*/false);
|
||||
*side_packet = *side_packets.begin();
|
||||
}
|
||||
|
||||
void AddMultiStreamCallback(
|
||||
const std::vector<std::string>& streams,
|
||||
std::function<void(const std::vector<Packet>&)> callback,
|
||||
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
|
||||
bool observe_timestamp_bounds) {
|
||||
CHECK(config);
|
||||
CHECK(side_packet);
|
||||
CHECK(side_packets);
|
||||
CalculatorGraphConfig::Node* sink_node = config->add_node();
|
||||
const std::string name = GetUnusedNodeName(
|
||||
*config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_")));
|
||||
|
@ -179,15 +191,23 @@ void AddMultiStreamCallback(
|
|||
sink_node->add_input_stream(stream_name);
|
||||
}
|
||||
|
||||
if (observe_timestamp_bounds) {
|
||||
const std::string observe_ts_bounds_packet_name = GetUnusedSidePacketName(
|
||||
*config, absl::StrCat(name, "_observe_ts_bounds"));
|
||||
sink_node->add_input_side_packet(absl::StrCat(
|
||||
"OBSERVE_TIMESTAMP_BOUNDS:", observe_ts_bounds_packet_name));
|
||||
InsertIfNotPresent(side_packets, observe_ts_bounds_packet_name,
|
||||
MakePacket<bool>(true));
|
||||
}
|
||||
const std::string input_side_packet_name =
|
||||
GetUnusedSidePacketName(*config, absl::StrCat(name, "_callback"));
|
||||
side_packet->first = input_side_packet_name;
|
||||
sink_node->add_input_side_packet(
|
||||
absl::StrCat("VECTOR_CALLBACK:", input_side_packet_name));
|
||||
|
||||
side_packet->second =
|
||||
InsertIfNotPresent(
|
||||
side_packets, input_side_packet_name,
|
||||
MakePacket<std::function<void(const std::vector<Packet>&)>>(
|
||||
std::move(callback));
|
||||
std::move(callback)));
|
||||
}
|
||||
|
||||
void AddCallbackWithHeaderCalculator(const std::string& stream_name,
|
||||
|
@ -240,6 +260,10 @@ absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) {
|
|||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "InputSidePackets must use tags.";
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS")) {
|
||||
cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Set<bool>();
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
}
|
||||
|
||||
int count = allow_multiple_streams ? cc->Inputs().NumEntries("") : 1;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
|
@ -266,6 +290,12 @@ absl::Status CallbackCalculator::Open(CalculatorContext* cc) {
|
|||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "missing callback.";
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("OBSERVE_TIMESTAMP_BOUNDS") &&
|
||||
!cc->InputSidePackets().Tag("OBSERVE_TIMESTAMP_BOUNDS").Get<bool>()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "The value of the OBSERVE_TIMESTAMP_BOUNDS input side packet "
|
||||
"must be set to true";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -115,6 +115,12 @@ void AddMultiStreamCallback(
|
|||
std::function<void(const std::vector<Packet>&)> callback,
|
||||
CalculatorGraphConfig* config, std::pair<std::string, Packet>* side_packet);
|
||||
|
||||
void AddMultiStreamCallback(
|
||||
const std::vector<std::string>& streams,
|
||||
std::function<void(const std::vector<Packet>&)> callback,
|
||||
CalculatorGraphConfig* config, std::map<std::string, Packet>* side_packets,
|
||||
bool observe_timestamp_bounds = false);
|
||||
|
||||
// Add a CallbackWithHeaderCalculator to intercept packets sent on
|
||||
// stream stream_name, and the header packet on stream stream_header.
|
||||
// The input side packet with the produced name callback_side_packet_name
|
||||
|
|
|
@ -146,5 +146,63 @@ TEST(CallbackTest, TestAddMultiStreamCallback) {
|
|||
EXPECT_THAT(sums, testing::ElementsAre(15, 7, 9));
|
||||
}
|
||||
|
||||
class TimestampBoundTestCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(1).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (count_ % 5 == 0) {
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_ + 1));
|
||||
cc->Outputs().Index(1).SetNextTimestampBound(Timestamp(count_ + 1));
|
||||
}
|
||||
++count_;
|
||||
if (count_ == 13) {
|
||||
return tool::StatusStop();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
int count_ = 1;
|
||||
};
|
||||
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
|
||||
|
||||
TEST(CallbackTest, TestAddMultiStreamCallbackWithTimestampNotification) {
|
||||
std::string config_str = R"(
|
||||
node {
|
||||
calculator: "TimestampBoundTestCalculator"
|
||||
output_stream: "foo"
|
||||
output_stream: "bar"
|
||||
}
|
||||
)";
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||
|
||||
std::vector<int> sums;
|
||||
|
||||
std::map<std::string, Packet> side_packets;
|
||||
tool::AddMultiStreamCallback(
|
||||
{"foo", "bar"},
|
||||
[&sums](const std::vector<Packet>& packets) {
|
||||
Packet foo_p = packets[0];
|
||||
Packet bar_p = packets[1];
|
||||
ASSERT_TRUE(foo_p.IsEmpty() && bar_p.IsEmpty());
|
||||
int foo = foo_p.Timestamp().Value();
|
||||
int bar = bar_p.Timestamp().Value();
|
||||
sums.push_back(foo + bar);
|
||||
},
|
||||
&graph_config, &side_packets, true);
|
||||
|
||||
CalculatorGraph graph(graph_config);
|
||||
MP_ASSERT_OK(graph.StartRun(side_packets));
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
|
||||
EXPECT_THAT(sums, testing::ElementsAre(10, 20));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
load("//mediapipe/gpu:metal.bzl", "metal_library")
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
|
||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||
|
||||
licenses(["notice"])
|
||||
|
@ -240,6 +240,12 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "gpu_origin_proto",
|
||||
srcs = ["gpu_origin.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "pixel_buffer_pool_util",
|
||||
srcs = ["pixel_buffer_pool_util.mm"],
|
||||
|
@ -460,6 +466,8 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_node",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/util:resource_cache",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
] + select({
|
||||
|
@ -760,8 +768,10 @@ cc_library(
|
|||
deps = [
|
||||
":gl_calculator_helper",
|
||||
":gl_quad_renderer",
|
||||
":gpu_buffer",
|
||||
":shader_util",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
|
||||
|
|
|
@ -563,8 +563,14 @@ class GlFenceSyncPoint : public GlSyncPoint {
|
|||
|
||||
void WaitOnGpu() override {
|
||||
if (!sync_) return;
|
||||
// TODO: do not wait if we are already on the same context?
|
||||
// TODO: do not wait if we are already on the same context?
|
||||
// WebGL2 specifies a waitSync call, but since cross-context
|
||||
// synchronization is not supported, it's actually a no-op. Firefox prints
|
||||
// a warning when it's called, so let's just skip the call. See
|
||||
// b/184637485 for details.
|
||||
#ifndef __EMSCRIPTEN__
|
||||
glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED);
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsReady() override {
|
||||
|
|